mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-05-24 22:19:09 -05:00
chore: bump reva
This commit is contained in:
committed by
Ralf Haferkamp
parent
393926bd73
commit
59bd11d02a
@@ -55,17 +55,17 @@ require (
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/mna/pigeon v1.3.0
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
|
||||
github.com/nats-io/nats-server/v2 v2.12.6
|
||||
github.com/nats-io/nats-server/v2 v2.14.0
|
||||
github.com/nats-io/nats.go v1.51.0
|
||||
github.com/oklog/run v1.2.0
|
||||
github.com/olekukonko/tablewriter v1.1.4
|
||||
github.com/onsi/ginkgo v1.16.5
|
||||
github.com/onsi/ginkgo/v2 v2.28.1
|
||||
github.com/onsi/gomega v1.39.1
|
||||
github.com/onsi/gomega v1.40.0
|
||||
github.com/open-policy-agent/opa v1.15.2
|
||||
github.com/opencloud-eu/icap-client v0.0.0-20250930132611-28a2afe62d89
|
||||
github.com/opencloud-eu/libre-graph-api-go v1.0.8-0.20260310090739-853d972b282d
|
||||
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260424125411-c5db28365753
|
||||
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260512061040-cd4be86c66b0
|
||||
github.com/opensearch-project/opensearch-go/v4 v4.6.0
|
||||
github.com/orcaman/concurrent-map v1.0.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
@@ -103,14 +103,14 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0
|
||||
go.opentelemetry.io/otel/sdk v1.43.0
|
||||
go.opentelemetry.io/otel/trace v1.43.0
|
||||
golang.org/x/crypto v0.49.0
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac
|
||||
golang.org/x/image v0.38.0
|
||||
golang.org/x/net v0.52.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/term v0.41.0
|
||||
golang.org/x/text v0.35.0
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/text v0.36.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9
|
||||
google.golang.org/grpc v1.80.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
@@ -122,7 +122,7 @@ require (
|
||||
|
||||
require (
|
||||
contrib.go.opencensus.io/exporter/prometheus v0.4.2 // indirect
|
||||
filippo.io/edwards25519 v1.1.1 // indirect
|
||||
filippo.io/edwards25519 v1.2.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
|
||||
github.com/Azure/go-ntlmssp v0.1.1 // indirect
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
@@ -136,7 +136,7 @@ require (
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/alexedwards/argon2id v1.0.0 // indirect
|
||||
github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964 // indirect
|
||||
github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op // indirect
|
||||
github.com/antithesishq/antithesis-sdk-go v0.7.0-default-no-op // indirect
|
||||
github.com/armon/go-radix v1.0.0 // indirect
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
@@ -220,7 +220,7 @@ require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/go-sql-driver/mysql v1.10.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
||||
github.com/go-test/deep v1.1.0 // indirect
|
||||
@@ -286,7 +286,7 @@ require (
|
||||
github.com/miekg/dns v1.1.68 // indirect
|
||||
github.com/mileusna/useragent v1.3.5 // indirect
|
||||
github.com/minio/crc64nvme v1.1.1 // indirect
|
||||
github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 // indirect
|
||||
github.com/minio/highwayhash v1.0.4 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/minio-go/v7 v7.0.99 // indirect
|
||||
github.com/mitchellh/copystructure v1.2.0 // indirect
|
||||
@@ -388,10 +388,10 @@ require (
|
||||
go.uber.org/zap v1.27.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/mod v0.34.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.org/x/tools v0.43.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d // indirect
|
||||
gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
|
||||
|
||||
@@ -39,8 +39,8 @@ contrib.go.opencensus.io/exporter/prometheus v0.4.2/go.mod h1:dvEHbiKmgvbr5pjaF9
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
||||
filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
|
||||
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
|
||||
github.com/Acconut/go-httptest-recorder v1.0.0 h1:TAv2dfnqp/l+SUvIaMAUK4GeN4+wqb6KZsFFFTGhoJg=
|
||||
github.com/Acconut/go-httptest-recorder v1.0.0/go.mod h1:CwQyhTH1kq/gLyWiRieo7c0uokpu3PXeyF/nZjUNtmM=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||
@@ -117,8 +117,8 @@ github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNg
|
||||
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op h1:kpBdlEPbRvff0mDD1gk7o9BhI16b9p5yYAXRlidpqJE=
|
||||
github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op/go.mod h1:IUpT2DPAKh6i/YhSbt6Gl3v2yvUZjmKncl7U91fup7E=
|
||||
github.com/antithesishq/antithesis-sdk-go v0.7.0-default-no-op h1:Z/MZK75wC/NSrkgqeNIa7jexam9uWzhLmFTSCPI/kn0=
|
||||
github.com/antithesishq/antithesis-sdk-go v0.7.0-default-no-op/go.mod h1:FQyySiasQQM8735Ddel3MRojmy4dA1IqCeyJ5jmPMbI=
|
||||
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
@@ -456,8 +456,8 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
|
||||
github.com/go-resty/resty/v2 v2.1.1-0.20191201195748-d7b97669fe48/go.mod h1:dZGr0i9PLlaaTD4H/hoZIDjQ+r6xq8mgbRzHZf7f2J8=
|
||||
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
|
||||
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw=
|
||||
github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
@@ -839,8 +839,8 @@ github.com/mileusna/useragent v1.3.5 h1:SJM5NzBmh/hO+4LGeATKpaEX9+b4vcGg2qXGLiNG
|
||||
github.com/mileusna/useragent v1.3.5/go.mod h1:3d8TOmwL/5I8pJjyVDteHtgDGcefrFUX4ccGOMKNYYc=
|
||||
github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI=
|
||||
github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
|
||||
github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 h1:KGuD/pM2JpL9FAYvBrnBBeENKZNh6eNtjqytV6TYjnk=
|
||||
github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ=
|
||||
github.com/minio/highwayhash v1.0.4 h1:asJizugGgchQod2ja9NJlGOWq4s7KsAWr5XUc9Clgl4=
|
||||
github.com/minio/highwayhash v1.0.4/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.99 h1:2vH/byrwUkIpFQFOilvTfaUpvAX3fEFhEzO+DR3DlCE=
|
||||
@@ -900,8 +900,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
|
||||
github.com/namedotcom/go v0.0.0-20180403034216-08470befbe04/go.mod h1:5sN+Lt1CaY4wsPvgQH/jsuJi4XO2ssZbdsIizr4CVC8=
|
||||
github.com/nats-io/jwt/v2 v2.8.1 h1:V0xpGuD/N8Mi+fQNDynXohVvp7ZztevW5io8CUWlPmU=
|
||||
github.com/nats-io/jwt/v2 v2.8.1/go.mod h1:nWnOEEiVMiKHQpnAy4eXlizVEtSfzacZ1Q43LIRavZg=
|
||||
github.com/nats-io/nats-server/v2 v2.12.6 h1:Egbx9Vl7Ch8wTtpXPGqbehkZ+IncKqShUxvrt1+Enc8=
|
||||
github.com/nats-io/nats-server/v2 v2.12.6/go.mod h1:4HPlrvtmSO3yd7KcElDNMx9kv5EBJBnJJzQPptXlheo=
|
||||
github.com/nats-io/nats-server/v2 v2.14.0 h1:+8q0HrDFotwLLcGH/legOEOnowunhK+aZ4GYBIWpQlM=
|
||||
github.com/nats-io/nats-server/v2 v2.14.0/go.mod h1:ImVUUDvfClJbb6cuJQRc1VmgDCXKM5ds0OoiG9MVOKo=
|
||||
github.com/nats-io/nats.go v1.51.0 h1:ByW84XTz6W03GSSsygsZcA+xgKK8vPGaa/FCAAEHnAI=
|
||||
github.com/nats-io/nats.go v1.51.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno=
|
||||
github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4=
|
||||
@@ -940,8 +940,8 @@ github.com/onsi/ginkgo/v2 v2.28.1/go.mod h1:CLtbVInNckU3/+gC8LzkGUb9oF+e8W8TdUsx
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/onsi/gomega v1.39.1 h1:1IJLAad4zjPn2PsnhH70V4DKRFlrCzGBNrNaru+Vf28=
|
||||
github.com/onsi/gomega v1.39.1/go.mod h1:hL6yVALoTOxeWudERyfppUcZXjMwIMLnuSfruD2lcfg=
|
||||
github.com/onsi/gomega v1.40.0 h1:Vtol0e1MghCD2ZVIilPDIg44XSL9l2QAn8ZNaljWcJc=
|
||||
github.com/onsi/gomega v1.40.0/go.mod h1:M/Uqpu/8qTjtzCLUA2zJHX9Iilrau25x1PdoSRbWh5A=
|
||||
github.com/open-policy-agent/opa v1.15.2 h1:dS9q+0Yvruq/VNvWJc5qCvCchn715OWc3HLHXn/UCCc=
|
||||
github.com/open-policy-agent/opa v1.15.2/go.mod h1:c6SN+7jSsUcKJLQc5P4yhwx8YYDRbjpAiGkBOTqxaa4=
|
||||
github.com/opencloud-eu/go-micro-plugins/v4/store/nats-js-kv v0.0.0-20250512152754-23325793059a h1:Sakl76blJAaM6NxylVkgSzktjo2dS504iDotEFJsh3M=
|
||||
@@ -952,8 +952,8 @@ github.com/opencloud-eu/inotifywaitgo v0.0.0-20251111171128-a390bae3c5e9 h1:dIft
|
||||
github.com/opencloud-eu/inotifywaitgo v0.0.0-20251111171128-a390bae3c5e9/go.mod h1:JWyDC6H+5oZRdUJUgKuaye+8Ph5hEs6HVzVoPKzWSGI=
|
||||
github.com/opencloud-eu/libre-graph-api-go v1.0.8-0.20260310090739-853d972b282d h1:JcqGDiyrcaQwVyV861TUyQgO7uEmsjkhfm7aQd84dOw=
|
||||
github.com/opencloud-eu/libre-graph-api-go v1.0.8-0.20260310090739-853d972b282d/go.mod h1:pzatilMEHZFT3qV7C/X3MqOa3NlRQuYhlRhZTL+hN6Q=
|
||||
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260424125411-c5db28365753 h1:/FpQdybaNb3OAISHmHRrh/4aWYQep3nVSYKzYt2F+jE=
|
||||
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260424125411-c5db28365753/go.mod h1:msu4TkFw7Jxog0QRbGPxyQOJG9sago5nc+f//y+bbpI=
|
||||
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260512061040-cd4be86c66b0 h1:e4w34sW1gXixTKi9z+odF6IKGyvisvu97xfYEXOvRGE=
|
||||
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260512061040-cd4be86c66b0/go.mod h1:SoRYtNJ9ha83YdUUep5wYF7F5/OIhgED7ZSgqudhpNo=
|
||||
github.com/opencloud-eu/secure v0.0.0-20260312082735-b6f5cb2244e4 h1:l2oB/RctH+t8r7QBj5p8thfEHCM/jF35aAY3WQ3hADI=
|
||||
github.com/opencloud-eu/secure v0.0.0-20260312082735-b6f5cb2244e4/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
@@ -1358,8 +1358,8 @@ golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -1397,8 +1397,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -1566,8 +1566,8 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -1580,8 +1580,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
@@ -1642,8 +1642,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f
|
||||
golang.org/x/tools v0.0.0-20210112230658-8b4aab62c064/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||
golang.org/x/tools/godoc v0.1.0-deprecated h1:o+aZ1BOj6Hsx/GBdJO/s815sqftjSnrZZwyYTHODvtk=
|
||||
golang.org/x/tools/godoc v0.1.0-deprecated/go.mod h1:qM63CriJ961IHWmnWa9CjZnBndniPt4a3CK0PVB9bIg=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
+4
-2
@@ -7,8 +7,10 @@ import "filippo.io/edwards25519"
|
||||
This library implements the edwards25519 elliptic curve, exposing the necessary APIs to build a wide array of higher-level primitives.
|
||||
Read the docs at [pkg.go.dev/filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519).
|
||||
|
||||
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17. It now tracks the upstream codebase and extends it with additional functionality.
|
||||
The package tracks the upstream standard library package `crypto/internal/fips140/edwards25519` and extends it with additional functionality.
|
||||
|
||||
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `golang.org/x/crypto/curve25519` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of `crypto/internal/edwards25519`/`crypto/ed25519/internal/edwards25519` or `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
|
||||
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17.
|
||||
|
||||
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `crypto/ecdh` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of the internal `edwards25519` package or of `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
|
||||
|
||||
Since this package is meant to curb proliferation of edwards25519 implementations in the Go ecosystem, it welcomes requests for new APIs or reviewable performance improvements.
|
||||
|
||||
+3
-3
@@ -10,11 +10,11 @@
|
||||
// the curve used by the Ed25519 signature scheme.
|
||||
//
|
||||
// Most users don't need this package, and should instead use crypto/ed25519 for
|
||||
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
|
||||
// github.com/gtank/ristretto255 for prime order group logic.
|
||||
// signatures, crypto/ecdh for Diffie-Hellman, or github.com/gtank/ristretto255
|
||||
// for prime order group logic.
|
||||
//
|
||||
// However, developers who do need to interact with low-level edwards25519
|
||||
// operations can use this package, which is an extended version of
|
||||
// crypto/internal/edwards25519 from the standard library repackaged as
|
||||
// crypto/internal/fips140/edwards25519 from the standard library repackaged as
|
||||
// an importable module.
|
||||
package edwards25519
|
||||
|
||||
+59
-8
@@ -9,6 +9,7 @@ package edwards25519
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
"filippo.io/edwards25519/field"
|
||||
)
|
||||
@@ -100,13 +101,15 @@ func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
|
||||
//
|
||||
// u = (1 + y) / (1 - y)
|
||||
//
|
||||
// where y = Y / Z.
|
||||
// where y = Y / Z and therefore
|
||||
//
|
||||
// u = (Z + Y) / (Z - Y)
|
||||
|
||||
var y, recip, u field.Element
|
||||
var n, r, u field.Element
|
||||
|
||||
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
|
||||
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
|
||||
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
|
||||
n.Add(&v.z, &v.y) // n = Z + Y
|
||||
r.Invert(r.Subtract(&v.z, &v.y)) // r = 1 / (Z - Y)
|
||||
u.Multiply(&n, &r) // u = n * r
|
||||
|
||||
return copyFieldElement(buf, &u)
|
||||
}
|
||||
@@ -124,7 +127,7 @@ func (v *Point) MultByCofactor(p *Point) *Point {
|
||||
return v.fromP1xP1(&result)
|
||||
}
|
||||
|
||||
// Given k > 0, set s = s**(2*i).
|
||||
// Given k > 0, set s = s**(2*k).
|
||||
func (s *Scalar) pow2k(k int) {
|
||||
for i := 0; i < k; i++ {
|
||||
s.Multiply(s, s)
|
||||
@@ -250,12 +253,14 @@ func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
|
||||
// between each point in the multiscalar equation.
|
||||
|
||||
// Build lookup tables for each point
|
||||
tables := make([]projLookupTable, len(points))
|
||||
tables := make([]projLookupTable, 0, 2) // avoid allocation for small sizes
|
||||
tables = slices.Grow(tables, len(points))[:len(points)]
|
||||
for i := range tables {
|
||||
tables[i].FromP3(points[i])
|
||||
}
|
||||
// Compute signed radix-16 digits for each scalar
|
||||
digits := make([][64]int8, len(scalars))
|
||||
digits := make([][64]int8, 0, 2) // avoid allocation for small sizes
|
||||
digits = slices.Grow(digits, len(scalars))[:len(scalars)]
|
||||
for i := range digits {
|
||||
digits[i] = scalars[i].signedRadix16()
|
||||
}
|
||||
@@ -348,3 +353,49 @@ func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Poin
|
||||
v.fromP2(tmp2)
|
||||
return v
|
||||
}
|
||||
|
||||
// Select sets v to a if cond == 1 and to b if cond == 0.
|
||||
func (v *Point) Select(a, b *Point, cond int) *Point {
|
||||
checkInitialized(a, b)
|
||||
v.x.Select(&a.x, &b.x, cond)
|
||||
v.y.Select(&a.y, &b.y, cond)
|
||||
v.z.Select(&a.z, &b.z, cond)
|
||||
v.t.Select(&a.t, &b.t, cond)
|
||||
return v
|
||||
}
|
||||
|
||||
// Double sets v = p + p, and returns v.
|
||||
func (v *Point) Double(p *Point) *Point {
|
||||
checkInitialized(p)
|
||||
|
||||
pp := new(projP2).FromP3(p)
|
||||
p1 := new(projP1xP1).Double(pp)
|
||||
return v.fromP1xP1(p1)
|
||||
}
|
||||
|
||||
func (v *Point) addCached(p *Point, qCached *projCached) *Point {
|
||||
result := new(projP1xP1).Add(p, qCached)
|
||||
return v.fromP1xP1(result)
|
||||
}
|
||||
|
||||
// ScalarMultSlow sets v = x * q, and returns v. It doesn't precompute a large
|
||||
// table, so it is considerably slower, but requires less memory.
|
||||
//
|
||||
// The scalar multiplication is done in constant time.
|
||||
func (v *Point) ScalarMultSlow(x *Scalar, q *Point) *Point {
|
||||
checkInitialized(q)
|
||||
|
||||
s := x.Bytes()
|
||||
qCached := new(projCached).FromP3(q)
|
||||
v.Set(NewIdentityPoint())
|
||||
t := new(Point)
|
||||
|
||||
for i := 255; i >= 0; i-- {
|
||||
v.Double(v)
|
||||
t.addCached(v, qCached)
|
||||
cond := (s[i/8] >> (i % 8)) & 1
|
||||
v.Select(t, v, int(cond))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
+17
-17
@@ -90,11 +90,7 @@ func (v *Element) Add(a, b *Element) *Element {
|
||||
v.l2 = a.l2 + b.l2
|
||||
v.l3 = a.l3 + b.l3
|
||||
v.l4 = a.l4 + b.l4
|
||||
// Using the generic implementation here is actually faster than the
|
||||
// assembly. Probably because the body of this function is so simple that
|
||||
// the compiler can figure out better optimizations by inlining the carry
|
||||
// propagation.
|
||||
return v.carryPropagateGeneric()
|
||||
return v.carryPropagate()
|
||||
}
|
||||
|
||||
// Subtract sets v = a - b, and returns v.
|
||||
@@ -232,18 +228,22 @@ func (v *Element) bytes(out *[32]byte) []byte {
|
||||
t := *v
|
||||
t.reduce()
|
||||
|
||||
var buf [8]byte
|
||||
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
|
||||
bitsOffset := i * 51
|
||||
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
|
||||
for i, bb := range buf {
|
||||
off := bitsOffset/8 + i
|
||||
if off >= len(out) {
|
||||
break
|
||||
}
|
||||
out[off] |= bb
|
||||
}
|
||||
}
|
||||
// Pack five 51-bit limbs into four 64-bit words:
|
||||
//
|
||||
// 255 204 153 102 51 0
|
||||
// ├──l4──┼──l3──┼──l2──┼──l1──┼──l0──┤
|
||||
// ├───u3───┼───u2───┼───u1───┼───u0───┤
|
||||
// 256 192 128 64 0
|
||||
|
||||
u0 := t.l1<<51 | t.l0
|
||||
u1 := t.l2<<(102-64) | t.l1>>(64-51)
|
||||
u2 := t.l3<<(153-128) | t.l2>>(128-102)
|
||||
u3 := t.l4<<(204-192) | t.l3>>(192-153)
|
||||
|
||||
binary.LittleEndian.PutUint64(out[0*8:], u0)
|
||||
binary.LittleEndian.PutUint64(out[1*8:], u1)
|
||||
binary.LittleEndian.PutUint64(out[2*8:], u2)
|
||||
binary.LittleEndian.PutUint64(out[3*8:], u3)
|
||||
|
||||
return out[:]
|
||||
}
|
||||
|
||||
+1
-2
@@ -1,7 +1,6 @@
|
||||
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
|
||||
|
||||
//go:build amd64 && gc && !purego
|
||||
// +build amd64,gc,!purego
|
||||
//go:build !purego
|
||||
|
||||
package field
|
||||
|
||||
|
||||
+111
-92
@@ -1,7 +1,6 @@
|
||||
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
|
||||
|
||||
//go:build amd64 && gc && !purego
|
||||
// +build amd64,gc,!purego
|
||||
//go:build !purego
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
@@ -17,32 +16,36 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
|
||||
MOVQ DX, SI
|
||||
|
||||
// r0 += 19×a1×b4
|
||||
MOVQ 8(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
MOVQ 8(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
|
||||
// r0 += 19×a2×b3
|
||||
MOVQ 16(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
MOVQ 16(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
|
||||
// r0 += 19×a3×b2
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 16(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
MOVQ 24(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 16(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
|
||||
// r0 += 19×a4×b1
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 8(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
MOVQ 32(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 8(BX)
|
||||
ADDQ AX, DI
|
||||
ADCQ DX, SI
|
||||
|
||||
// r1 = a0×b1
|
||||
MOVQ (CX), AX
|
||||
@@ -57,25 +60,28 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
|
||||
ADCQ DX, R8
|
||||
|
||||
// r1 += 19×a2×b4
|
||||
MOVQ 16(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
MOVQ 16(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
|
||||
// r1 += 19×a3×b3
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
MOVQ 24(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
|
||||
// r1 += 19×a4×b2
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 16(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
MOVQ 32(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 16(BX)
|
||||
ADDQ AX, R9
|
||||
ADCQ DX, R8
|
||||
|
||||
// r2 = a0×b2
|
||||
MOVQ (CX), AX
|
||||
@@ -96,18 +102,20 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
|
||||
ADCQ DX, R10
|
||||
|
||||
// r2 += 19×a3×b4
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R11
|
||||
ADCQ DX, R10
|
||||
MOVQ 24(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R11
|
||||
ADCQ DX, R10
|
||||
|
||||
// r2 += 19×a4×b3
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, R11
|
||||
ADCQ DX, R10
|
||||
MOVQ 32(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 24(BX)
|
||||
ADDQ AX, R11
|
||||
ADCQ DX, R10
|
||||
|
||||
// r3 = a0×b3
|
||||
MOVQ (CX), AX
|
||||
@@ -134,11 +142,12 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
|
||||
ADCQ DX, R12
|
||||
|
||||
// r3 += 19×a4×b4
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R13
|
||||
ADCQ DX, R12
|
||||
MOVQ 32(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 32(BX)
|
||||
ADDQ AX, R13
|
||||
ADCQ DX, R12
|
||||
|
||||
// r4 = a0×b4
|
||||
MOVQ (CX), AX
|
||||
@@ -232,18 +241,22 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
|
||||
MOVQ DX, BX
|
||||
|
||||
// r0 += 38×l1×l4
|
||||
MOVQ 8(CX), AX
|
||||
IMUL3Q $0x26, AX, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, SI
|
||||
ADCQ DX, BX
|
||||
MOVQ 8(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, SI
|
||||
ADCQ DX, BX
|
||||
|
||||
// r0 += 38×l2×l3
|
||||
MOVQ 16(CX), AX
|
||||
IMUL3Q $0x26, AX, AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, SI
|
||||
ADCQ DX, BX
|
||||
MOVQ 16(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, SI
|
||||
ADCQ DX, BX
|
||||
|
||||
// r1 = 2×l0×l1
|
||||
MOVQ (CX), AX
|
||||
@@ -253,18 +266,21 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
|
||||
MOVQ DX, DI
|
||||
|
||||
// r1 += 38×l2×l4
|
||||
MOVQ 16(CX), AX
|
||||
IMUL3Q $0x26, AX, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R8
|
||||
ADCQ DX, DI
|
||||
MOVQ 16(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R8
|
||||
ADCQ DX, DI
|
||||
|
||||
// r1 += 19×l3×l3
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, R8
|
||||
ADCQ DX, DI
|
||||
MOVQ 24(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, R8
|
||||
ADCQ DX, DI
|
||||
|
||||
// r2 = 2×l0×l2
|
||||
MOVQ (CX), AX
|
||||
@@ -280,11 +296,13 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
|
||||
ADCQ DX, R9
|
||||
|
||||
// r2 += 38×l3×l4
|
||||
MOVQ 24(CX), AX
|
||||
IMUL3Q $0x26, AX, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R10
|
||||
ADCQ DX, R9
|
||||
MOVQ 24(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R10
|
||||
ADCQ DX, R9
|
||||
|
||||
// r3 = 2×l0×l3
|
||||
MOVQ (CX), AX
|
||||
@@ -294,18 +312,19 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
|
||||
MOVQ DX, R11
|
||||
|
||||
// r3 += 2×l1×l2
|
||||
MOVQ 8(CX), AX
|
||||
IMUL3Q $0x02, AX, AX
|
||||
MULQ 16(CX)
|
||||
ADDQ AX, R12
|
||||
ADCQ DX, R11
|
||||
MOVQ 8(CX), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 16(CX)
|
||||
ADDQ AX, R12
|
||||
ADCQ DX, R11
|
||||
|
||||
// r3 += 19×l4×l4
|
||||
MOVQ 32(CX), AX
|
||||
IMUL3Q $0x13, AX, AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R12
|
||||
ADCQ DX, R11
|
||||
MOVQ 32(CX), DX
|
||||
LEAQ (DX)(DX*8), AX
|
||||
LEAQ (DX)(AX*2), AX
|
||||
MULQ 32(CX)
|
||||
ADDQ AX, R12
|
||||
ADCQ DX, R11
|
||||
|
||||
// r4 = 2×l0×l4
|
||||
MOVQ (CX), AX
|
||||
@@ -315,11 +334,11 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
|
||||
MOVQ DX, R13
|
||||
|
||||
// r4 += 2×l1×l3
|
||||
MOVQ 8(CX), AX
|
||||
IMUL3Q $0x02, AX, AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, R14
|
||||
ADCQ DX, R13
|
||||
MOVQ 8(CX), AX
|
||||
SHLQ $0x01, AX
|
||||
MULQ 24(CX)
|
||||
ADDQ AX, R14
|
||||
ADCQ DX, R13
|
||||
|
||||
// r4 += l2×l2
|
||||
MOVQ 16(CX), AX
|
||||
|
||||
+1
-2
@@ -2,8 +2,7 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !amd64 || !gc || purego
|
||||
// +build !amd64 !gc purego
|
||||
//go:build !amd64 || purego
|
||||
|
||||
package field
|
||||
|
||||
|
||||
-16
@@ -1,16 +0,0 @@
|
||||
// Copyright (c) 2020 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build arm64 && gc && !purego
|
||||
// +build arm64,gc,!purego
|
||||
|
||||
package field
|
||||
|
||||
//go:noescape
|
||||
func carryPropagate(v *Element)
|
||||
|
||||
func (v *Element) carryPropagate() *Element {
|
||||
carryPropagate(v)
|
||||
return v
|
||||
}
|
||||
-42
@@ -1,42 +0,0 @@
|
||||
// Copyright (c) 2020 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build arm64 && gc && !purego
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// carryPropagate works exactly like carryPropagateGeneric and uses the
|
||||
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
|
||||
// avoids loading R0-R4 twice and uses LDP and STP.
|
||||
//
|
||||
// See https://golang.org/issues/43145 for the main compiler issue.
|
||||
//
|
||||
// func carryPropagate(v *Element)
|
||||
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
|
||||
MOVD v+0(FP), R20
|
||||
|
||||
LDP 0(R20), (R0, R1)
|
||||
LDP 16(R20), (R2, R3)
|
||||
MOVD 32(R20), R4
|
||||
|
||||
AND $0x7ffffffffffff, R0, R10
|
||||
AND $0x7ffffffffffff, R1, R11
|
||||
AND $0x7ffffffffffff, R2, R12
|
||||
AND $0x7ffffffffffff, R3, R13
|
||||
AND $0x7ffffffffffff, R4, R14
|
||||
|
||||
ADD R0>>51, R11, R11
|
||||
ADD R1>>51, R12, R12
|
||||
ADD R2>>51, R13, R13
|
||||
ADD R3>>51, R14, R14
|
||||
// R4>>51 * 19 + R10 -> R10
|
||||
LSR $51, R4, R21
|
||||
MOVD $19, R22
|
||||
MADD R22, R10, R21, R10
|
||||
|
||||
STP (R10, R11), 0(R20)
|
||||
STP (R12, R13), 16(R20)
|
||||
MOVD R14, 32(R20)
|
||||
|
||||
RET
|
||||
-12
@@ -1,12 +0,0 @@
|
||||
// Copyright (c) 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !arm64 || !gc || purego
|
||||
// +build !arm64 !gc purego
|
||||
|
||||
package field
|
||||
|
||||
func (v *Element) carryPropagate() *Element {
|
||||
return v.carryPropagateGeneric()
|
||||
}
|
||||
+88
-82
@@ -12,20 +12,42 @@ type uint128 struct {
|
||||
lo, hi uint64
|
||||
}
|
||||
|
||||
// mul64 returns a * b.
|
||||
func mul64(a, b uint64) uint128 {
|
||||
// mul returns a * b.
|
||||
func mul(a, b uint64) uint128 {
|
||||
hi, lo := bits.Mul64(a, b)
|
||||
return uint128{lo, hi}
|
||||
}
|
||||
|
||||
// addMul64 returns v + a * b.
|
||||
func addMul64(v uint128, a, b uint64) uint128 {
|
||||
// addMul returns v + a * b.
|
||||
func addMul(v uint128, a, b uint64) uint128 {
|
||||
hi, lo := bits.Mul64(a, b)
|
||||
lo, c := bits.Add64(lo, v.lo, 0)
|
||||
hi, _ = bits.Add64(hi, v.hi, c)
|
||||
return uint128{lo, hi}
|
||||
}
|
||||
|
||||
// mul19 returns v * 19.
|
||||
func mul19(v uint64) uint64 {
|
||||
// Using this approach seems to yield better optimizations than *19.
|
||||
return v + (v+v<<3)<<1
|
||||
}
|
||||
|
||||
// addMul19 returns v + 19 * a * b, where a and b are at most 52 bits.
|
||||
func addMul19(v uint128, a, b uint64) uint128 {
|
||||
hi, lo := bits.Mul64(mul19(a), b)
|
||||
lo, c := bits.Add64(lo, v.lo, 0)
|
||||
hi, _ = bits.Add64(hi, v.hi, c)
|
||||
return uint128{lo, hi}
|
||||
}
|
||||
|
||||
// addMul38 returns v + 38 * a * b, where a and b are at most 52 bits.
|
||||
func addMul38(v uint128, a, b uint64) uint128 {
|
||||
hi, lo := bits.Mul64(mul19(a), b*2)
|
||||
lo, c := bits.Add64(lo, v.lo, 0)
|
||||
hi, _ = bits.Add64(hi, v.hi, c)
|
||||
return uint128{lo, hi}
|
||||
}
|
||||
|
||||
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
|
||||
func shiftRightBy51(a uint128) uint64 {
|
||||
return (a.hi << (64 - 51)) | (a.lo >> 51)
|
||||
@@ -76,45 +98,40 @@ func feMulGeneric(v, a, b *Element) {
|
||||
//
|
||||
// Finally we add up the columns into wide, overlapping limbs.
|
||||
|
||||
a1_19 := a1 * 19
|
||||
a2_19 := a2 * 19
|
||||
a3_19 := a3 * 19
|
||||
a4_19 := a4 * 19
|
||||
|
||||
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
|
||||
r0 := mul64(a0, b0)
|
||||
r0 = addMul64(r0, a1_19, b4)
|
||||
r0 = addMul64(r0, a2_19, b3)
|
||||
r0 = addMul64(r0, a3_19, b2)
|
||||
r0 = addMul64(r0, a4_19, b1)
|
||||
r0 := mul(a0, b0)
|
||||
r0 = addMul19(r0, a1, b4)
|
||||
r0 = addMul19(r0, a2, b3)
|
||||
r0 = addMul19(r0, a3, b2)
|
||||
r0 = addMul19(r0, a4, b1)
|
||||
|
||||
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
|
||||
r1 := mul64(a0, b1)
|
||||
r1 = addMul64(r1, a1, b0)
|
||||
r1 = addMul64(r1, a2_19, b4)
|
||||
r1 = addMul64(r1, a3_19, b3)
|
||||
r1 = addMul64(r1, a4_19, b2)
|
||||
r1 := mul(a0, b1)
|
||||
r1 = addMul(r1, a1, b0)
|
||||
r1 = addMul19(r1, a2, b4)
|
||||
r1 = addMul19(r1, a3, b3)
|
||||
r1 = addMul19(r1, a4, b2)
|
||||
|
||||
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
|
||||
r2 := mul64(a0, b2)
|
||||
r2 = addMul64(r2, a1, b1)
|
||||
r2 = addMul64(r2, a2, b0)
|
||||
r2 = addMul64(r2, a3_19, b4)
|
||||
r2 = addMul64(r2, a4_19, b3)
|
||||
r2 := mul(a0, b2)
|
||||
r2 = addMul(r2, a1, b1)
|
||||
r2 = addMul(r2, a2, b0)
|
||||
r2 = addMul19(r2, a3, b4)
|
||||
r2 = addMul19(r2, a4, b3)
|
||||
|
||||
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
|
||||
r3 := mul64(a0, b3)
|
||||
r3 = addMul64(r3, a1, b2)
|
||||
r3 = addMul64(r3, a2, b1)
|
||||
r3 = addMul64(r3, a3, b0)
|
||||
r3 = addMul64(r3, a4_19, b4)
|
||||
r3 := mul(a0, b3)
|
||||
r3 = addMul(r3, a1, b2)
|
||||
r3 = addMul(r3, a2, b1)
|
||||
r3 = addMul(r3, a3, b0)
|
||||
r3 = addMul19(r3, a4, b4)
|
||||
|
||||
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
|
||||
r4 := mul64(a0, b4)
|
||||
r4 = addMul64(r4, a1, b3)
|
||||
r4 = addMul64(r4, a2, b2)
|
||||
r4 = addMul64(r4, a3, b1)
|
||||
r4 = addMul64(r4, a4, b0)
|
||||
r4 := mul(a0, b4)
|
||||
r4 = addMul(r4, a1, b3)
|
||||
r4 = addMul(r4, a2, b2)
|
||||
r4 = addMul(r4, a3, b1)
|
||||
r4 = addMul(r4, a4, b0)
|
||||
|
||||
// After the multiplication, we need to reduce (carry) the five coefficients
|
||||
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
|
||||
@@ -149,7 +166,7 @@ func feMulGeneric(v, a, b *Element) {
|
||||
c3 := shiftRightBy51(r3)
|
||||
c4 := shiftRightBy51(r4)
|
||||
|
||||
rr0 := r0.lo&maskLow51Bits + c4*19
|
||||
rr0 := r0.lo&maskLow51Bits + mul19(c4)
|
||||
rr1 := r1.lo&maskLow51Bits + c0
|
||||
rr2 := r2.lo&maskLow51Bits + c1
|
||||
rr3 := r3.lo&maskLow51Bits + c2
|
||||
@@ -158,8 +175,12 @@ func feMulGeneric(v, a, b *Element) {
|
||||
// Now all coefficients fit into 64-bit registers but are still too large to
|
||||
// be passed around as an Element. We therefore do one last carry chain,
|
||||
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
|
||||
*v = Element{rr0, rr1, rr2, rr3, rr4}
|
||||
v.carryPropagate()
|
||||
|
||||
v.l0 = rr0&maskLow51Bits + mul19(rr4>>51)
|
||||
v.l1 = rr1&maskLow51Bits + rr0>>51
|
||||
v.l2 = rr2&maskLow51Bits + rr1>>51
|
||||
v.l3 = rr3&maskLow51Bits + rr2>>51
|
||||
v.l4 = rr4&maskLow51Bits + rr3>>51
|
||||
}
|
||||
|
||||
func feSquareGeneric(v, a *Element) {
|
||||
@@ -190,44 +211,31 @@ func feSquareGeneric(v, a *Element) {
|
||||
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
|
||||
// --------------------------------------
|
||||
// r4 r3 r2 r1 r0
|
||||
//
|
||||
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
|
||||
// only three Mul64 and four Add64, instead of five and eight.
|
||||
|
||||
l0_2 := l0 * 2
|
||||
l1_2 := l1 * 2
|
||||
|
||||
l1_38 := l1 * 38
|
||||
l2_38 := l2 * 38
|
||||
l3_38 := l3 * 38
|
||||
|
||||
l3_19 := l3 * 19
|
||||
l4_19 := l4 * 19
|
||||
|
||||
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
|
||||
r0 := mul64(l0, l0)
|
||||
r0 = addMul64(r0, l1_38, l4)
|
||||
r0 = addMul64(r0, l2_38, l3)
|
||||
r0 := mul(l0, l0)
|
||||
r0 = addMul38(r0, l1, l4)
|
||||
r0 = addMul38(r0, l2, l3)
|
||||
|
||||
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
|
||||
r1 := mul64(l0_2, l1)
|
||||
r1 = addMul64(r1, l2_38, l4)
|
||||
r1 = addMul64(r1, l3_19, l3)
|
||||
r1 := mul(l0*2, l1)
|
||||
r1 = addMul38(r1, l2, l4)
|
||||
r1 = addMul19(r1, l3, l3)
|
||||
|
||||
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
|
||||
r2 := mul64(l0_2, l2)
|
||||
r2 = addMul64(r2, l1, l1)
|
||||
r2 = addMul64(r2, l3_38, l4)
|
||||
r2 := mul(l0*2, l2)
|
||||
r2 = addMul(r2, l1, l1)
|
||||
r2 = addMul38(r2, l3, l4)
|
||||
|
||||
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
|
||||
r3 := mul64(l0_2, l3)
|
||||
r3 = addMul64(r3, l1_2, l2)
|
||||
r3 = addMul64(r3, l4_19, l4)
|
||||
r3 := mul(l0*2, l3)
|
||||
r3 = addMul(r3, l1*2, l2)
|
||||
r3 = addMul19(r3, l4, l4)
|
||||
|
||||
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
|
||||
r4 := mul64(l0_2, l4)
|
||||
r4 = addMul64(r4, l1_2, l3)
|
||||
r4 = addMul64(r4, l2, l2)
|
||||
r4 := mul(l0*2, l4)
|
||||
r4 = addMul(r4, l1*2, l3)
|
||||
r4 = addMul(r4, l2, l2)
|
||||
|
||||
c0 := shiftRightBy51(r0)
|
||||
c1 := shiftRightBy51(r1)
|
||||
@@ -235,32 +243,30 @@ func feSquareGeneric(v, a *Element) {
|
||||
c3 := shiftRightBy51(r3)
|
||||
c4 := shiftRightBy51(r4)
|
||||
|
||||
rr0 := r0.lo&maskLow51Bits + c4*19
|
||||
rr0 := r0.lo&maskLow51Bits + mul19(c4)
|
||||
rr1 := r1.lo&maskLow51Bits + c0
|
||||
rr2 := r2.lo&maskLow51Bits + c1
|
||||
rr3 := r3.lo&maskLow51Bits + c2
|
||||
rr4 := r4.lo&maskLow51Bits + c3
|
||||
|
||||
*v = Element{rr0, rr1, rr2, rr3, rr4}
|
||||
v.carryPropagate()
|
||||
v.l0 = rr0&maskLow51Bits + mul19(rr4>>51)
|
||||
v.l1 = rr1&maskLow51Bits + rr0>>51
|
||||
v.l2 = rr2&maskLow51Bits + rr1>>51
|
||||
v.l3 = rr3&maskLow51Bits + rr2>>51
|
||||
v.l4 = rr4&maskLow51Bits + rr3>>51
|
||||
}
|
||||
|
||||
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
|
||||
// carryPropagate brings the limbs below 52 bits by applying the reduction
|
||||
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
|
||||
func (v *Element) carryPropagateGeneric() *Element {
|
||||
c0 := v.l0 >> 51
|
||||
c1 := v.l1 >> 51
|
||||
c2 := v.l2 >> 51
|
||||
c3 := v.l3 >> 51
|
||||
c4 := v.l4 >> 51
|
||||
|
||||
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
|
||||
func (v *Element) carryPropagate() *Element {
|
||||
// (l4>>51) is at most 64 - 51 = 13 bits, so (l4>>51)*19 is at most 18 bits, and
|
||||
// the final l0 will be at most 52 bits. Similarly for the rest.
|
||||
v.l0 = v.l0&maskLow51Bits + c4*19
|
||||
v.l1 = v.l1&maskLow51Bits + c0
|
||||
v.l2 = v.l2&maskLow51Bits + c1
|
||||
v.l3 = v.l3&maskLow51Bits + c2
|
||||
v.l4 = v.l4&maskLow51Bits + c3
|
||||
l0 := v.l0
|
||||
v.l0 = v.l0&maskLow51Bits + mul19(v.l4>>51)
|
||||
v.l4 = v.l4&maskLow51Bits + v.l3>>51
|
||||
v.l3 = v.l3&maskLow51Bits + v.l2>>51
|
||||
v.l2 = v.l2&maskLow51Bits + v.l1>>51
|
||||
v.l1 = v.l1&maskLow51Bits + l0>>51
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
+53
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
if [ "$#" -ne 1 ]; then
|
||||
echo "Usage: $0 <tag>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
TAG="$1"
|
||||
TMPDIR="$(mktemp -d)"
|
||||
|
||||
cleanup() {
|
||||
rm -rf "$TMPDIR"
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
command -v git >/dev/null
|
||||
command -v git-filter-repo >/dev/null
|
||||
|
||||
if [ -d "$HOME/go/.git" ]; then
|
||||
REFERENCE=(--reference "$HOME/go" --dissociate)
|
||||
else
|
||||
REFERENCE=()
|
||||
fi
|
||||
|
||||
git -c advice.detachedHead=false clone --no-checkout "${REFERENCE[@]}" \
|
||||
-b "$TAG" https://go.googlesource.com/go.git "$TMPDIR"
|
||||
|
||||
# Simplify the history graph by removing the dev.boringcrypto branches, whose
|
||||
# merges end up empty after grafting anyway. This also fixes a weird quirk
|
||||
# (maybe a git-filter-repo bug?) where only one file from an old path,
|
||||
# src/crypto/ed25519/internal/edwards25519/const.go, would still exist in the
|
||||
# filtered repo.
|
||||
git -C "$TMPDIR" replace --graft f771edd7f9 99f1bf54eb
|
||||
git -C "$TMPDIR" replace --graft 109c13b64f c2f96e686f
|
||||
git -C "$TMPDIR" replace --graft aa4da4f189 912f075047
|
||||
|
||||
git -C "$TMPDIR" filter-repo --force \
|
||||
--paths-from-file /dev/stdin \
|
||||
--prune-empty always \
|
||||
--prune-degenerate always \
|
||||
--tag-callback 'tag.skip()' <<'EOF'
|
||||
src/crypto/internal/fips140/edwards25519
|
||||
src/crypto/internal/edwards25519
|
||||
src/crypto/ed25519/internal/edwards25519
|
||||
EOF
|
||||
|
||||
git fetch "$TMPDIR"
|
||||
git update-ref "refs/heads/upstream/$TAG" FETCH_HEAD
|
||||
|
||||
echo
|
||||
echo "Fetched upstream history up to $TAG. Merge with:"
|
||||
echo -e "\tgit merge --no-ff --no-commit --allow-unrelated-histories upstream/$TAG"
|
||||
+18
-9
@@ -7,6 +7,7 @@ package edwards25519
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// A Scalar is an integer modulo
|
||||
@@ -179,15 +180,23 @@ func isReduced(s []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
switch {
|
||||
case s[i] > scalarMinusOneBytes[i]:
|
||||
return false
|
||||
case s[i] < scalarMinusOneBytes[i]:
|
||||
return true
|
||||
}
|
||||
}
|
||||
return true
|
||||
s0 := binary.LittleEndian.Uint64(s[:8])
|
||||
s1 := binary.LittleEndian.Uint64(s[8:16])
|
||||
s2 := binary.LittleEndian.Uint64(s[16:24])
|
||||
s3 := binary.LittleEndian.Uint64(s[24:])
|
||||
|
||||
l0 := binary.LittleEndian.Uint64(scalarMinusOneBytes[:8])
|
||||
l1 := binary.LittleEndian.Uint64(scalarMinusOneBytes[8:16])
|
||||
l2 := binary.LittleEndian.Uint64(scalarMinusOneBytes[16:24])
|
||||
l3 := binary.LittleEndian.Uint64(scalarMinusOneBytes[24:])
|
||||
|
||||
// Do a constant time subtraction chain scalarMinusOneBytes - s. If there is
|
||||
// a borrow at the end, then s > scalarMinusOneBytes.
|
||||
_, b := bits.Sub64(l0, s0, 0)
|
||||
_, b = bits.Sub64(l1, s1, b)
|
||||
_, b = bits.Sub64(l2, s2, b)
|
||||
_, b = bits.Sub64(l3, s3, b)
|
||||
return b == 0
|
||||
}
|
||||
|
||||
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
|
||||
|
||||
+1
-3
@@ -4,9 +4,7 @@
|
||||
|
||||
package edwards25519
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
)
|
||||
import "crypto/subtle"
|
||||
|
||||
// A dynamic lookup table for variable-base, constant-time scalar muls.
|
||||
type projLookupTable struct {
|
||||
|
||||
+2
-2
@@ -14,8 +14,8 @@
|
||||
//
|
||||
// [Antithesis Go SDK]: https://antithesis.com/docs/using_antithesis/sdk/go/
|
||||
// [Antithesis platform]: https://antithesis.com
|
||||
// [test properties]: https://antithesis.com/docs/using_antithesis/properties/
|
||||
// [workload]: https://antithesis.com/docs/getting_started/first_test/
|
||||
// [test properties]: https://antithesis.com/docs/properties_assertions/properties/
|
||||
// [workload]: https://antithesis.com/docs/test_templates/first_test/
|
||||
// [antithesis-go-generator]: https://antithesis.com/docs/using_antithesis/sdk/go/instrumentor/
|
||||
// [triage report]: https://antithesis.com/docs/reports/
|
||||
// [here]: https://antithesis.com/docs/using_antithesis/sdk/fallback/
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ package internal
|
||||
// --------------------------------------------------------------------------------
|
||||
// Versions
|
||||
// --------------------------------------------------------------------------------
|
||||
const SDK_Version = "0.6.0"
|
||||
const SDK_Version = "0.7.0"
|
||||
const Protocol_Version = "1.1.0"
|
||||
|
||||
// --------------------------------------------------------------------------------
|
||||
|
||||
+7
-3
@@ -18,8 +18,8 @@ Alex Snast <alexsn at fb.com>
|
||||
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
|
||||
Andrew Reid <andrew.reid at tixtrack.com>
|
||||
Animesh Ray <mail.rayanimesh at gmail.com>
|
||||
Arne Hormann <arnehormann at gmail.com>
|
||||
Ariel Mashraki <ariel at mashraki.co.il>
|
||||
Arne Hormann <arnehormann at gmail.com>
|
||||
Artur Melanchyk <artur.melanchyk@gmail.com>
|
||||
Asta Xie <xiemengjun at gmail.com>
|
||||
B Lamarche <blam413 at gmail.com>
|
||||
@@ -38,6 +38,7 @@ Daniel Montoya <dsmontoyam at gmail.com>
|
||||
Daniel Nichter <nil at codenode.com>
|
||||
Daniël van Eeden <git at myname.nl>
|
||||
Dave Protasowski <dprotaso at gmail.com>
|
||||
Demouth <yuya at demouth.net>
|
||||
Diego Dupin <diego.dupin at gmail.com>
|
||||
Dirkjan Bussink <d.bussink at gmail.com>
|
||||
DisposaBoy <disposaboy at dby.me>
|
||||
@@ -66,6 +67,7 @@ Jeff Hodges <jeff at somethingsimilar.com>
|
||||
Jeffrey Charles <jeffreycharles at gmail.com>
|
||||
Jennifer Purevsuren <jennifer at dolthub.com>
|
||||
Jerome Meyer <jxmeyer at gmail.com>
|
||||
Jiabin Zhang <jiabin.z at qq.com>
|
||||
Jiajia Zhong <zhong2plus at gmail.com>
|
||||
Jian Zhen <zhenjl at gmail.com>
|
||||
Joe Mann <contact at joemann.co.uk>
|
||||
@@ -85,10 +87,12 @@ Linh Tran Tuan <linhduonggnu at gmail.com>
|
||||
Lion Yang <lion at aosc.xyz>
|
||||
Luca Looz <luca.looz92 at gmail.com>
|
||||
Lucas Liu <extrafliu at gmail.com>
|
||||
Lunny Xiao <xiaolunwen at gmail.com>
|
||||
Luke Scott <luke at webconnex.com>
|
||||
Lunny Xiao <xiaolunwen at gmail.com>
|
||||
Maciej Zimnoch <maciej.zimnoch at codilime.com>
|
||||
Michael Woolnough <michael.woolnough at gmail.com>
|
||||
Minh Quang <minhquang4334 at gmail.com>
|
||||
Morgan Tocker <tocker at gmail.com>
|
||||
Nao Yokotsuka <yokotukanao at gmail.com>
|
||||
Nathanial Murphy <nathanial.murphy at gmail.com>
|
||||
Nicola Peduzzi <thenikso at gmail.com>
|
||||
@@ -99,7 +103,6 @@ Paul Bonser <misterpib at gmail.com>
|
||||
Paulius Lozys <pauliuslozys at gmail.com>
|
||||
Peter Schultz <peter.schultz at classmarkets.com>
|
||||
Phil Porada <philporada at gmail.com>
|
||||
Minh Quang <minhquang4334 at gmail.com>
|
||||
Rebecca Chin <rchin at pivotal.io>
|
||||
Reed Allman <rdallman10 at gmail.com>
|
||||
Richard Wilkes <wilkes at me.com>
|
||||
@@ -134,6 +137,7 @@ Ziheng Lyu <zihenglv at gmail.com>
|
||||
# Organizations
|
||||
|
||||
Barracuda Networks, Inc.
|
||||
Block, Inc.
|
||||
Counting Ltd.
|
||||
Defined Networking Inc.
|
||||
DigitalOcean Inc.
|
||||
|
||||
+16
-3
@@ -1,13 +1,26 @@
|
||||
# Changelog
|
||||
|
||||
## v1.10.0 (2026-04-28)
|
||||
|
||||
* Fix `getSystemVar("max_allowed_packet")` potentially returned wrong value. (#1754)
|
||||
This affects only when `maxAllowedPacket=0` is set.
|
||||
|
||||
* Bump filippo.io/edwards25519 from 1.1.1 to 1.2.0. (#1756)
|
||||
While older versions have reported CVEs, they do not affect go-mysql.
|
||||
|
||||
* Update Go versions to 1.24-1.26. (#1763)
|
||||
|
||||
* Enhance interpolateParams to correctly handle placeholders. (#1732)
|
||||
The question mark (?) within strings and comments will no longer be treated as a placeholder.
|
||||
|
||||
|
||||
## v1.9.3 (2025-06-13)
|
||||
|
||||
* `tx.Commit()` and `tx.Rollback()` returned `ErrInvalidConn` always.
|
||||
Now they return cached real error if present. (#1690)
|
||||
|
||||
* Optimize reading small resultsets to fix performance regression
|
||||
introduced by compression protocol support. (#1707)
|
||||
|
||||
* Optimize reading small result sets to fix a performance regression
|
||||
introduced by compression protocol support. (`#1707`)
|
||||
* Fix `db.Ping()` on compressed connection. (#1723)
|
||||
|
||||
|
||||
|
||||
+5
-2
@@ -1,5 +1,8 @@
|
||||
# Go-MySQL-Driver
|
||||
|
||||
[](https://deepwiki.com/go-sql-driver/mysql)
|
||||
|
||||
|
||||
A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package
|
||||
|
||||

|
||||
@@ -42,8 +45,8 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
|
||||
|
||||
## Requirements
|
||||
|
||||
* Go 1.21 or higher. We aim to support the 3 latest versions of Go.
|
||||
* MySQL (5.7+) and MariaDB (10.5+) are supported.
|
||||
* Go 1.24 or higher. We aim to support the 3 latest versions of Go.
|
||||
* MySQL (5.7+) and MariaDB (10.5+) are supported by maintainers.
|
||||
* [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP.
|
||||
* Do not ask questions about TiDB in our issue tracker or forum.
|
||||
* [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang)
|
||||
|
||||
+1
-1
@@ -305,7 +305,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
|
||||
if !mc.cfg.AllowNativePasswords {
|
||||
return nil, ErrNativePassword
|
||||
}
|
||||
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.4.5/page_protocol_connection_phase_authentication_methods_native_password_authentication.html
|
||||
// Native password authentication only need and will need 20-byte challenge.
|
||||
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
|
||||
return authResp, nil
|
||||
|
||||
-1
@@ -7,7 +7,6 @@
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos
|
||||
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
|
||||
|
||||
package mysql
|
||||
|
||||
|
||||
-1
@@ -7,7 +7,6 @@
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos
|
||||
// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
|
||||
|
||||
package mysql
|
||||
|
||||
|
||||
+185
-90
@@ -33,7 +33,8 @@ type mysqlConn struct {
|
||||
connector *connector
|
||||
maxAllowedPacket int
|
||||
maxWriteSize int
|
||||
flags clientFlag
|
||||
capabilities capabilityFlag
|
||||
extCapabilities extendedCapabilityFlag
|
||||
status statusFlag
|
||||
sequence uint8
|
||||
compressSequence uint8
|
||||
@@ -171,7 +172,7 @@ func (mc *mysqlConn) close() {
|
||||
}
|
||||
|
||||
// Closes the network connection and unsets internal variables. Do not call this
|
||||
// function after successfully authentication, call Close instead. This function
|
||||
// function after successful authentication, call Close instead. This function
|
||||
// is called before auth or on auth failure because MySQL will have already
|
||||
// closed the network connection.
|
||||
func (mc *mysqlConn) cleanup() {
|
||||
@@ -223,13 +224,21 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
||||
columnCount, err := stmt.readPrepareResultPacket()
|
||||
if err == nil {
|
||||
if stmt.paramCount > 0 {
|
||||
if err = mc.readUntilEOF(); err != nil {
|
||||
if err = mc.skipColumns(stmt.paramCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if columnCount > 0 {
|
||||
err = mc.readUntilEOF()
|
||||
if mc.extCapabilities&clientCacheMetadata != 0 {
|
||||
if stmt.columns, err = mc.readColumns(int(columnCount), nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err = mc.skipColumns(int(columnCount)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,100 +246,184 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
|
||||
// Number of ? should be same to len(args)
|
||||
if strings.Count(query, "?") != len(args) {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
|
||||
const (
|
||||
stateNormal = iota
|
||||
stateString
|
||||
stateEscape
|
||||
stateEOLComment
|
||||
stateSlashStarComment
|
||||
stateBacktick
|
||||
)
|
||||
|
||||
const (
|
||||
QUOTE_BYTE = byte('\'')
|
||||
DBL_QUOTE_BYTE = byte('"')
|
||||
BACKSLASH_BYTE = byte('\\')
|
||||
QUESTION_MARK_BYTE = byte('?')
|
||||
SLASH_BYTE = byte('/')
|
||||
STAR_BYTE = byte('*')
|
||||
HASH_BYTE = byte('#')
|
||||
MINUS_BYTE = byte('-')
|
||||
LINE_FEED_BYTE = byte('\n')
|
||||
BACKTICK_BYTE = byte('`')
|
||||
)
|
||||
|
||||
buf, err := mc.buf.takeCompleteBuffer()
|
||||
if err != nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
mc.cleanup()
|
||||
// interpolateParams would be called before sending any query.
|
||||
// So its safe to retry.
|
||||
return "", driver.ErrBadConn
|
||||
}
|
||||
buf = buf[:0]
|
||||
state := stateNormal
|
||||
singleQuotes := false
|
||||
lastChar := byte(0)
|
||||
argPos := 0
|
||||
lenQuery := len(query)
|
||||
lastIdx := 0
|
||||
|
||||
for i := 0; i < len(query); i++ {
|
||||
q := strings.IndexByte(query[i:], '?')
|
||||
if q == -1 {
|
||||
buf = append(buf, query[i:]...)
|
||||
break
|
||||
}
|
||||
buf = append(buf, query[i:i+q]...)
|
||||
i += q
|
||||
|
||||
arg := args[argPos]
|
||||
argPos++
|
||||
|
||||
if arg == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
for i := range lenQuery {
|
||||
currentChar := query[i]
|
||||
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
|
||||
state = stateString
|
||||
lastChar = currentChar
|
||||
continue
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case int64:
|
||||
buf = strconv.AppendInt(buf, v, 10)
|
||||
case uint64:
|
||||
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
|
||||
buf = strconv.AppendUint(buf, v, 10)
|
||||
case float64:
|
||||
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
|
||||
case bool:
|
||||
if v {
|
||||
buf = append(buf, '1')
|
||||
} else {
|
||||
buf = append(buf, '0')
|
||||
switch currentChar {
|
||||
case STAR_BYTE:
|
||||
if state == stateNormal && lastChar == SLASH_BYTE {
|
||||
state = stateSlashStarComment
|
||||
}
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
buf = append(buf, "'0000-00-00'"...)
|
||||
} else {
|
||||
buf = append(buf, '\'')
|
||||
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
case SLASH_BYTE:
|
||||
if state == stateSlashStarComment && lastChar == STAR_BYTE {
|
||||
state = stateNormal
|
||||
// Clear lastChar so the '/' that closed the comment isn't
|
||||
// reused to start a new comment with a following '*'.
|
||||
lastChar = 0
|
||||
continue
|
||||
}
|
||||
case json.RawMessage:
|
||||
buf = append(buf, '\'')
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeBytesBackslash(buf, v)
|
||||
} else {
|
||||
buf = escapeBytesQuotes(buf, v)
|
||||
case HASH_BYTE:
|
||||
if state == stateNormal {
|
||||
state = stateEOLComment
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
case []byte:
|
||||
if v == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
} else {
|
||||
buf = append(buf, "_binary'"...)
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeBytesBackslash(buf, v)
|
||||
case MINUS_BYTE:
|
||||
if state == stateNormal && lastChar == MINUS_BYTE {
|
||||
// -- only starts a comment if followed by whitespace or control char
|
||||
if i+1 < lenQuery {
|
||||
nextChar := query[i+1]
|
||||
if nextChar == ' ' || nextChar == '\t' || nextChar == '\n' || nextChar == '\r' {
|
||||
state = stateEOLComment
|
||||
}
|
||||
} else {
|
||||
buf = escapeBytesQuotes(buf, v)
|
||||
state = stateEOLComment
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case string:
|
||||
buf = append(buf, '\'')
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeStringBackslash(buf, v)
|
||||
} else {
|
||||
buf = escapeStringQuotes(buf, v)
|
||||
case LINE_FEED_BYTE:
|
||||
if state == stateEOLComment {
|
||||
state = stateNormal
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
default:
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
case DBL_QUOTE_BYTE:
|
||||
if state == stateNormal {
|
||||
state = stateString
|
||||
singleQuotes = false
|
||||
} else if state == stateString && !singleQuotes {
|
||||
state = stateNormal
|
||||
} else if state == stateEscape {
|
||||
state = stateString
|
||||
}
|
||||
case QUOTE_BYTE:
|
||||
if state == stateNormal {
|
||||
state = stateString
|
||||
singleQuotes = true
|
||||
} else if state == stateString && singleQuotes {
|
||||
state = stateNormal
|
||||
} else if state == stateEscape {
|
||||
state = stateString
|
||||
}
|
||||
case BACKSLASH_BYTE:
|
||||
if state == stateString && !noBackslashEscapes {
|
||||
state = stateEscape
|
||||
}
|
||||
case QUESTION_MARK_BYTE:
|
||||
if state == stateNormal {
|
||||
if argPos >= len(args) {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
buf = append(buf, query[lastIdx:i]...)
|
||||
arg := args[argPos]
|
||||
argPos++
|
||||
|
||||
if len(buf)+4 > mc.maxAllowedPacket {
|
||||
return "", driver.ErrSkip
|
||||
if arg == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
lastIdx = i + 1
|
||||
break
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case int64:
|
||||
buf = strconv.AppendInt(buf, v, 10)
|
||||
case uint64:
|
||||
buf = strconv.AppendUint(buf, v, 10)
|
||||
case float64:
|
||||
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
|
||||
case bool:
|
||||
if v {
|
||||
buf = append(buf, '1')
|
||||
} else {
|
||||
buf = append(buf, '0')
|
||||
}
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
buf = append(buf, "'0000-00-00'"...)
|
||||
} else {
|
||||
buf = append(buf, '\'')
|
||||
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case json.RawMessage:
|
||||
if noBackslashEscapes {
|
||||
buf = escapeBytesQuotes(buf, v, false)
|
||||
} else {
|
||||
buf = escapeBytesBackslash(buf, v, false)
|
||||
}
|
||||
case []byte:
|
||||
if v == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
} else {
|
||||
if noBackslashEscapes {
|
||||
buf = escapeBytesQuotes(buf, v, true)
|
||||
} else {
|
||||
buf = escapeBytesBackslash(buf, v, true)
|
||||
}
|
||||
}
|
||||
case string:
|
||||
if noBackslashEscapes {
|
||||
buf = escapeStringQuotes(buf, v)
|
||||
} else {
|
||||
buf = escapeStringBackslash(buf, v)
|
||||
}
|
||||
default:
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
|
||||
if len(buf)+4 > mc.maxAllowedPacket {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
lastIdx = i + 1
|
||||
}
|
||||
case BACKTICK_BYTE:
|
||||
if state == stateBacktick {
|
||||
state = stateNormal
|
||||
} else if state == stateNormal {
|
||||
state = stateBacktick
|
||||
}
|
||||
}
|
||||
lastChar = currentChar
|
||||
}
|
||||
buf = append(buf, query[lastIdx:]...)
|
||||
if argPos != len(args) {
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
@@ -370,19 +463,19 @@ func (mc *mysqlConn) exec(query string) error {
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := handleOk.readResultSetHeaderPacket()
|
||||
resLen, _, err := handleOk.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
// columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
if err := mc.skipColumns(resLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
if err := mc.skipRows(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -419,7 +512,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
|
||||
|
||||
// Read Result
|
||||
var resLen int
|
||||
resLen, err = handleOk.readResultSetHeaderPacket()
|
||||
resLen, _, err = handleOk.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -439,21 +532,20 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
|
||||
}
|
||||
|
||||
// Columns
|
||||
rows.rs.columns, err = mc.readColumns(resLen)
|
||||
rows.rs.columns, err = mc.readColumns(resLen, nil)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
// Gets the value of the given MySQL System Variable
|
||||
// The returned byte slice is only valid until the next read
|
||||
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
||||
func (mc *mysqlConn) getSystemVar(name string) (string, error) {
|
||||
// Send command
|
||||
handleOk := mc.clearResult()
|
||||
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := handleOk.readResultSetHeaderPacket()
|
||||
resLen, _, err := handleOk.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
rows := new(textRows)
|
||||
rows.mc = mc
|
||||
@@ -461,17 +553,20 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
if err := mc.skipColumns(resLen); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
dest := make([]driver.Value, resLen)
|
||||
if err = rows.readRow(dest); err == nil {
|
||||
return dest[0].([]byte), mc.readUntilEOF()
|
||||
// Convert to string before skipRows, which may
|
||||
// overwrite the read buffer that dest[0] points into.
|
||||
val := string(dest[0].([]byte))
|
||||
return val, mc.skipRows()
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
// cancel is called when the query has canceled.
|
||||
|
||||
+7
-5
@@ -42,7 +42,7 @@ func encodeConnectionAttributes(cfg *Config) string {
|
||||
}
|
||||
|
||||
// user-defined connection attributes
|
||||
for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {
|
||||
for connAttr := range strings.SplitSeq(cfg.ConnectionAttributes, ",") {
|
||||
k, v, found := strings.Cut(connAttr, ":")
|
||||
if !found {
|
||||
continue
|
||||
@@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
mc.buf = newBuffer()
|
||||
|
||||
// Reading Handshake Initialization Packet
|
||||
authData, plugin, err := mc.readHandshakePacket()
|
||||
authData, serverCapabilities, serverExtCapabilities, plugin, err := mc.readHandshakePacket()
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
@@ -153,6 +153,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
mc.initCapabilities(serverCapabilities, serverExtCapabilities, mc.cfg)
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
@@ -161,13 +162,14 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
// Handle response to auth packet, switch methods if possible
|
||||
if err = mc.handleAuthResult(authData, plugin); err != nil {
|
||||
// Authentication failed and MySQL has already closed the connection
|
||||
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
|
||||
// (https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html#sect_protocol_connection_phase_fast_path_fails).
|
||||
// Do not send COM_QUIT, just cleanup and return the error.
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
|
||||
// compression is enabled after auth, not right after sending handshake response.
|
||||
if mc.capabilities&clientCompress > 0 {
|
||||
mc.compress = true
|
||||
mc.compIO = newCompIO(mc)
|
||||
}
|
||||
@@ -180,7 +182,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
mc.Close()
|
||||
return nil, err
|
||||
}
|
||||
n, err := strconv.Atoi(string(maxap))
|
||||
n, err := strconv.Atoi(maxap)
|
||||
if err != nil {
|
||||
mc.Close()
|
||||
return nil, fmt.Errorf("invalid max_allowed_packet value (%q): %w", maxap, err)
|
||||
|
||||
+17
-4
@@ -32,7 +32,7 @@ const (
|
||||
)
|
||||
|
||||
// MySQL constants documentation:
|
||||
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
|
||||
|
||||
const (
|
||||
iOK byte = 0x00
|
||||
@@ -42,11 +42,12 @@ const (
|
||||
iERR byte = 0xff
|
||||
)
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
|
||||
type clientFlag uint32
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html
|
||||
// https://mariadb.com/kb/en/connection/#capabilities
|
||||
type capabilityFlag uint32
|
||||
|
||||
const (
|
||||
clientLongPassword clientFlag = 1 << iota
|
||||
clientMySQL capabilityFlag = 1 << iota
|
||||
clientFoundRows
|
||||
clientLongFlag
|
||||
clientConnectWithDB
|
||||
@@ -73,6 +74,18 @@ const (
|
||||
clientDeprecateEOF
|
||||
)
|
||||
|
||||
// https://mariadb.com/kb/en/connection/#capabilities
|
||||
type extendedCapabilityFlag uint32
|
||||
|
||||
const (
|
||||
progressIndicator extendedCapabilityFlag = 1 << iota
|
||||
clientComMulti
|
||||
clientStmtBulkOperations
|
||||
clientExtendedMetadata
|
||||
clientCacheMetadata
|
||||
clientUnitBulkResult
|
||||
)
|
||||
|
||||
const (
|
||||
comQuit byte = iota + 1
|
||||
comInitDB
|
||||
|
||||
+4
-5
@@ -15,6 +15,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -157,9 +158,7 @@ func (cfg *Config) Clone() *Config {
|
||||
}
|
||||
if len(cp.Params) > 0 {
|
||||
cp.Params = make(map[string]string, len(cfg.Params))
|
||||
for k, v := range cfg.Params {
|
||||
cp.Params[k] = v
|
||||
}
|
||||
maps.Copy(cp.Params, cfg.Params)
|
||||
}
|
||||
if cfg.pubKey != nil {
|
||||
cp.pubKey = &rsa.PublicKey{
|
||||
@@ -414,7 +413,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
|
||||
if dsn[j] == '@' {
|
||||
// username[:password]
|
||||
// Find the first ':' in dsn[:j]
|
||||
for k = 0; k < j; k++ {
|
||||
for k = 0; k < j; k++ { // We cannot use k = range j here, because we use dsn[:k] below
|
||||
if dsn[k] == ':' {
|
||||
cfg.Passwd = dsn[k+1 : j]
|
||||
break
|
||||
@@ -477,7 +476,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
|
||||
// parseDSNParams parses the DSN "query string"
|
||||
// Values must be url.QueryEscape'ed
|
||||
func parseDSNParams(cfg *Config, params string) (err error) {
|
||||
for _, v := range strings.Split(params, "&") {
|
||||
for v := range strings.SplitSeq(params, "&") {
|
||||
key, value, found := strings.Cut(v, "=")
|
||||
if !found {
|
||||
continue
|
||||
|
||||
+21
-17
@@ -120,23 +120,24 @@ func (mf *mysqlField) typeDatabaseName() string {
|
||||
}
|
||||
|
||||
var (
|
||||
scanTypeFloat32 = reflect.TypeOf(float32(0))
|
||||
scanTypeFloat64 = reflect.TypeOf(float64(0))
|
||||
scanTypeInt8 = reflect.TypeOf(int8(0))
|
||||
scanTypeInt16 = reflect.TypeOf(int16(0))
|
||||
scanTypeInt32 = reflect.TypeOf(int32(0))
|
||||
scanTypeInt64 = reflect.TypeOf(int64(0))
|
||||
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
|
||||
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
|
||||
scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
|
||||
scanTypeUint8 = reflect.TypeOf(uint8(0))
|
||||
scanTypeUint16 = reflect.TypeOf(uint16(0))
|
||||
scanTypeUint32 = reflect.TypeOf(uint32(0))
|
||||
scanTypeUint64 = reflect.TypeOf(uint64(0))
|
||||
scanTypeString = reflect.TypeOf("")
|
||||
scanTypeNullString = reflect.TypeOf(sql.NullString{})
|
||||
scanTypeBytes = reflect.TypeOf([]byte{})
|
||||
scanTypeUnknown = reflect.TypeOf(new(any))
|
||||
scanTypeFloat32 = reflect.TypeFor[float32]()
|
||||
scanTypeFloat64 = reflect.TypeFor[float64]()
|
||||
scanTypeInt8 = reflect.TypeFor[int8]()
|
||||
scanTypeInt16 = reflect.TypeFor[int16]()
|
||||
scanTypeInt32 = reflect.TypeFor[int32]()
|
||||
scanTypeInt64 = reflect.TypeFor[int64]()
|
||||
scanTypeNullFloat = reflect.TypeFor[sql.NullFloat64]()
|
||||
scanTypeNullInt = reflect.TypeFor[sql.NullInt64]()
|
||||
scanTypeNullUint = reflect.TypeFor[sql.Null[uint64]]()
|
||||
scanTypeNullTime = reflect.TypeFor[sql.NullTime]()
|
||||
scanTypeUint8 = reflect.TypeFor[uint8]()
|
||||
scanTypeUint16 = reflect.TypeFor[uint16]()
|
||||
scanTypeUint32 = reflect.TypeFor[uint32]()
|
||||
scanTypeUint64 = reflect.TypeFor[uint64]()
|
||||
scanTypeString = reflect.TypeFor[string]()
|
||||
scanTypeNullString = reflect.TypeFor[sql.NullString]()
|
||||
scanTypeBytes = reflect.TypeFor[[]byte]()
|
||||
scanTypeUnknown = reflect.TypeFor[*any]()
|
||||
)
|
||||
|
||||
type mysqlField struct {
|
||||
@@ -185,6 +186,9 @@ func (mf *mysqlField) scanType() reflect.Type {
|
||||
}
|
||||
return scanTypeInt64
|
||||
}
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeNullUint
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeFloat:
|
||||
|
||||
+1
-4
@@ -95,10 +95,7 @@ const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead a
|
||||
|
||||
func (mc *okHandler) handleInFileRequest(name string) (err error) {
|
||||
var rdr io.Reader
|
||||
packetSize := defaultPacketSize
|
||||
if mc.maxWriteSize < packetSize {
|
||||
packetSize = mc.maxWriteSize
|
||||
}
|
||||
packetSize := min(mc.maxWriteSize, defaultPacketSize)
|
||||
|
||||
if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
|
||||
// The server might return an an absolute path. See issue #355.
|
||||
|
||||
+205
-138
@@ -179,20 +179,22 @@ func (mc *mysqlConn) writePacket(data []byte) error {
|
||||
******************************************************************************/
|
||||
|
||||
// Handshake Initialization Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
||||
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
|
||||
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
|
||||
func (mc *mysqlConn) readHandshakePacket() (data []byte, capabilities capabilityFlag, extendedCapabilities extendedCapabilityFlag, plugin string, err error) {
|
||||
data, err = mc.readPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if data[0] == iERR {
|
||||
return nil, "", mc.handleErrorPacket(data)
|
||||
err = mc.handleErrorPacket(data)
|
||||
return
|
||||
}
|
||||
|
||||
// protocol version [1 byte]
|
||||
if data[0] < minProtocolVersion {
|
||||
return nil, "", fmt.Errorf(
|
||||
return nil, 0, 0, "", fmt.Errorf(
|
||||
"unsupported protocol version %d. Version %d or higher is required",
|
||||
data[0],
|
||||
minProtocolVersion,
|
||||
@@ -210,15 +212,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
|
||||
pos += 8 + 1
|
||||
|
||||
// capability flags (lower 2 bytes) [2 bytes]
|
||||
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
||||
if mc.flags&clientProtocol41 == 0 {
|
||||
return nil, "", ErrOldProtocol
|
||||
capabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
||||
if capabilities&clientProtocol41 == 0 {
|
||||
return nil, capabilities, 0, "", ErrOldProtocol
|
||||
}
|
||||
if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
|
||||
if capabilities&clientSSL == 0 && mc.cfg.TLS != nil {
|
||||
if mc.cfg.AllowFallbackToPlaintext {
|
||||
mc.cfg.TLS = nil
|
||||
} else {
|
||||
return nil, "", ErrNoTLS
|
||||
return nil, capabilities, 0, "", ErrNoTLS
|
||||
}
|
||||
}
|
||||
pos += 2
|
||||
@@ -228,11 +230,16 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
|
||||
// status flags [2 bytes]
|
||||
pos += 3
|
||||
// capability flags (upper 2 bytes) [2 bytes]
|
||||
mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
|
||||
capabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
|
||||
pos += 2
|
||||
// length of auth-plugin-data [1 byte]
|
||||
// reserved (all [00]) [10 bytes]
|
||||
pos += 11
|
||||
// reserved (all [00]) [6 bytes]
|
||||
pos += 7
|
||||
if capabilities&clientMySQL == 0 {
|
||||
// MariaDB server extended flag
|
||||
extendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
||||
}
|
||||
pos += 4
|
||||
|
||||
// second part of the password cipher [minimum 13 bytes],
|
||||
// where len=MAX(13, length of auth-plugin-data - 8)
|
||||
@@ -260,82 +267,72 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
|
||||
// make a memory safe copy of the cipher slice
|
||||
var b [20]byte
|
||||
copy(b[:], authData)
|
||||
return b[:], plugin, nil
|
||||
return b[:], capabilities, extendedCapabilities, plugin, nil
|
||||
}
|
||||
|
||||
// make a memory safe copy of the cipher slice
|
||||
var b [8]byte
|
||||
copy(b[:], authData)
|
||||
return b[:], plugin, nil
|
||||
return b[:], capabilities, 0, plugin, nil
|
||||
}
|
||||
|
||||
// Client Authentication Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
|
||||
// Adjust client flags based on server support
|
||||
clientFlags := clientProtocol41 |
|
||||
clientSecureConn |
|
||||
clientLongPassword |
|
||||
clientTransactions |
|
||||
clientLocalFiles |
|
||||
clientPluginAuth |
|
||||
clientMultiResults |
|
||||
mc.flags&clientConnectAttrs |
|
||||
mc.flags&clientLongFlag
|
||||
// initCapabilities initializes the capabilities based on server support and configuration
|
||||
func (mc *mysqlConn) initCapabilities(serverCapabilities capabilityFlag, serverExtCapabilities extendedCapabilityFlag, cfg *Config) {
|
||||
clientCapabilities :=
|
||||
clientMySQL |
|
||||
clientLongFlag |
|
||||
clientProtocol41 |
|
||||
clientSecureConn |
|
||||
clientTransactions |
|
||||
clientPluginAuthLenEncClientData |
|
||||
clientLocalFiles |
|
||||
clientPluginAuth |
|
||||
clientMultiResults |
|
||||
clientConnectAttrs |
|
||||
clientDeprecateEOF
|
||||
|
||||
sendConnectAttrs := mc.flags&clientConnectAttrs != 0
|
||||
|
||||
if mc.cfg.ClientFoundRows {
|
||||
clientFlags |= clientFoundRows
|
||||
if cfg.ClientFoundRows {
|
||||
clientCapabilities |= clientFoundRows
|
||||
}
|
||||
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
|
||||
clientFlags |= clientCompress
|
||||
if cfg.compress {
|
||||
clientCapabilities |= clientCompress
|
||||
}
|
||||
// To enable TLS / SSL
|
||||
if mc.cfg.TLS != nil {
|
||||
clientFlags |= clientSSL
|
||||
clientCapabilities |= clientSSL
|
||||
}
|
||||
|
||||
if mc.cfg.MultiStatements {
|
||||
clientFlags |= clientMultiStatements
|
||||
clientCapabilities |= clientMultiStatements
|
||||
}
|
||||
if n := len(cfg.DBName); n > 0 {
|
||||
clientCapabilities |= clientConnectWithDB
|
||||
}
|
||||
|
||||
// encode length of the auth plugin data
|
||||
var authRespLEIBuf [9]byte
|
||||
authRespLen := len(authResp)
|
||||
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
|
||||
if len(authRespLEI) > 1 {
|
||||
// if the length can not be written in 1 byte, it must be written as a
|
||||
// length encoded integer
|
||||
clientFlags |= clientPluginAuthLenEncClientData
|
||||
}
|
||||
// only keep client capabilities that server have
|
||||
mc.capabilities = clientCapabilities & serverCapabilities
|
||||
|
||||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
|
||||
// set MariaDB extended clientCacheMetadata capability if server support it
|
||||
mc.extCapabilities = clientCacheMetadata & serverExtCapabilities
|
||||
}
|
||||
|
||||
// To specify a db name
|
||||
if n := len(mc.cfg.DBName); n > 0 {
|
||||
clientFlags |= clientConnectWithDB
|
||||
pktLen += n + 1
|
||||
}
|
||||
|
||||
// encode length of the connection attributes
|
||||
var connAttrsLEI []byte
|
||||
if sendConnectAttrs {
|
||||
var connAttrsLEIBuf [9]byte
|
||||
connAttrsLen := len(mc.connector.encodedAttributes)
|
||||
connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
|
||||
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
|
||||
}
|
||||
|
||||
// Calculate packet length and get buffer with that size
|
||||
data, err := mc.buf.takeBuffer(pktLen + 4)
|
||||
// Client Authentication Packet
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html
|
||||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
|
||||
// packet header 4
|
||||
// capabilities 4
|
||||
// maxPacketSize 4
|
||||
// collation id 1
|
||||
// filler 23
|
||||
data, err := mc.buf.takeSmallBuffer(4*3 + 24)
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return err
|
||||
}
|
||||
_ = data[4*3+23] // boundery check
|
||||
|
||||
// ClientFlags [32 bit]
|
||||
binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags))
|
||||
// clientCapabilities [32 bit]
|
||||
binary.LittleEndian.PutUint32(data[4:], uint32(mc.capabilities))
|
||||
|
||||
// MaxPacketSize [32 bit] (none)
|
||||
binary.LittleEndian.PutUint32(data[8:], 0)
|
||||
@@ -353,16 +350,26 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
|
||||
}
|
||||
|
||||
// Filler [23 bytes] (all 0x00)
|
||||
// or filler 19bytes + mariadb extCapabilities
|
||||
pos := 13
|
||||
for ; pos < 13+23; pos++ {
|
||||
data[pos] = 0
|
||||
if mc.capabilities&clientMySQL == 0 {
|
||||
for ; pos < 13+19; pos++ {
|
||||
data[pos] = 0
|
||||
}
|
||||
// MariaDB Extended Capabilities
|
||||
binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.extCapabilities))
|
||||
} else {
|
||||
for ; pos < 13+23; pos++ {
|
||||
data[pos] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// SSL Connection Request Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html
|
||||
// https://mariadb.com/kb/en/connection/#sslrequest-packet
|
||||
if mc.cfg.TLS != nil {
|
||||
// Send TLS / SSL request packet
|
||||
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
|
||||
if err := mc.writePacket(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -379,37 +386,35 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
|
||||
|
||||
// User [null terminated string]
|
||||
if len(mc.cfg.User) > 0 {
|
||||
pos += copy(data[pos:], mc.cfg.User)
|
||||
data = append(data, mc.cfg.User...)
|
||||
}
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
data = append(data, 0)
|
||||
|
||||
// Auth Data [length encoded integer]
|
||||
pos += copy(data[pos:], authRespLEI)
|
||||
pos += copy(data[pos:], authResp)
|
||||
data = appendLengthEncodedInteger(data, uint64(len(authResp)))
|
||||
data = append(data, authResp...)
|
||||
|
||||
// Databasename [null terminated string]
|
||||
if len(mc.cfg.DBName) > 0 {
|
||||
pos += copy(data[pos:], mc.cfg.DBName)
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
// Database name [null terminated string]
|
||||
if mc.capabilities&clientConnectWithDB != 0 {
|
||||
data = append(data, mc.cfg.DBName...)
|
||||
data = append(data, 0)
|
||||
}
|
||||
|
||||
pos += copy(data[pos:], plugin)
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
data = append(data, plugin...)
|
||||
data = append(data, 0)
|
||||
|
||||
// Connection Attributes
|
||||
if sendConnectAttrs {
|
||||
pos += copy(data[pos:], connAttrsLEI)
|
||||
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
|
||||
if mc.capabilities&clientConnectAttrs != 0 {
|
||||
connAttrsLen := len(mc.connector.encodedAttributes)
|
||||
data = appendLengthEncodedInteger(data, uint64(connAttrsLen))
|
||||
data = append(data, mc.connector.encodedAttributes...)
|
||||
}
|
||||
|
||||
// Send Auth packet
|
||||
return mc.writePacket(data[:pos])
|
||||
return mc.writePacket(data)
|
||||
}
|
||||
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_response.html
|
||||
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
|
||||
pktLen := 4 + len(authData)
|
||||
data, err := mc.buf.takeBuffer(pktLen)
|
||||
@@ -511,7 +516,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
|
||||
|
||||
case iEOF:
|
||||
if len(data) == 1 {
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_old_auth_switch_request.html
|
||||
return nil, "mysql_old_password", nil
|
||||
}
|
||||
pluginEndIndex := bytes.IndexByte(data, 0x00)
|
||||
@@ -545,36 +550,41 @@ func (mc *okHandler) readResultOK() error {
|
||||
|
||||
// Result Set Header Packet
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
|
||||
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
|
||||
func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) {
|
||||
// handleOkPacket replaces both values; other cases leave the values unchanged.
|
||||
mc.result.affectedRows = append(mc.result.affectedRows, 0)
|
||||
mc.result.insertIds = append(mc.result.insertIds, 0)
|
||||
|
||||
data, err := mc.conn().readPacket()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
switch data[0] {
|
||||
case iOK:
|
||||
return 0, mc.handleOkPacket(data)
|
||||
return 0, false, mc.handleOkPacket(data)
|
||||
|
||||
case iERR:
|
||||
return 0, mc.conn().handleErrorPacket(data)
|
||||
return 0, false, mc.conn().handleErrorPacket(data)
|
||||
|
||||
case iLocalInFile:
|
||||
return 0, mc.handleInFileRequest(string(data[1:]))
|
||||
return 0, false, mc.handleInFileRequest(string(data[1:]))
|
||||
}
|
||||
|
||||
// column count
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
|
||||
num, _, _ := readLengthEncodedInteger(data)
|
||||
// https://mariadb.com/kb/en/result-set-packets/#column-count-packet
|
||||
num, _, len := readLengthEncodedInteger(data)
|
||||
|
||||
if mc.extCapabilities&clientCacheMetadata != 0 {
|
||||
return int(num), data[len] == 0x01, nil
|
||||
}
|
||||
// ignore remaining data in the packet. see #1478.
|
||||
return int(num), nil
|
||||
return int(num), true, nil
|
||||
}
|
||||
|
||||
// Error Packet
|
||||
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html
|
||||
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
||||
if data[0] != iERR {
|
||||
return ErrMalformPkt
|
||||
@@ -656,7 +666,7 @@ func (mc *mysqlConn) clearResult() *okHandler {
|
||||
}
|
||||
|
||||
// Ok Packet
|
||||
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html
|
||||
func (mc *okHandler) handleOkPacket(data []byte) error {
|
||||
var n, m int
|
||||
var affectedRows, insertId uint64
|
||||
@@ -690,24 +700,19 @@ func (mc *okHandler) handleOkPacket(data []byte) error {
|
||||
}
|
||||
|
||||
// Read Packets as Field Packets until EOF-Packet or an Error appears
|
||||
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
|
||||
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_column_definition.html#sect_protocol_com_query_response_text_resultset_column_definition_41
|
||||
func (mc *mysqlConn) readColumns(count int, old []mysqlField) ([]mysqlField, error) {
|
||||
columns := make([]mysqlField, count)
|
||||
if len(old) != count {
|
||||
old = nil
|
||||
}
|
||||
|
||||
for i := 0; ; i++ {
|
||||
for i := range count {
|
||||
data, err := mc.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// EOF Packet
|
||||
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
|
||||
if i == count {
|
||||
return columns, nil
|
||||
}
|
||||
return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
|
||||
}
|
||||
|
||||
// Catalog
|
||||
pos, err := skipLengthEncodedString(data)
|
||||
if err != nil {
|
||||
@@ -728,7 +733,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
||||
return nil, err
|
||||
}
|
||||
pos += n
|
||||
columns[i].tableName = string(tableName)
|
||||
if old != nil && old[i].tableName == string(tableName) {
|
||||
// avoid allocating new string
|
||||
columns[i].tableName = old[i].tableName
|
||||
} else {
|
||||
columns[i].tableName = string(tableName)
|
||||
}
|
||||
} else {
|
||||
n, err = skipLengthEncodedString(data[pos:])
|
||||
if err != nil {
|
||||
@@ -749,7 +759,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
columns[i].name = string(name)
|
||||
if old != nil && old[i].name == string(name) {
|
||||
// avoid allocating new string
|
||||
columns[i].name = old[i].name
|
||||
} else {
|
||||
columns[i].name = string(name)
|
||||
}
|
||||
pos += n
|
||||
|
||||
// Original name [len coded string]
|
||||
@@ -780,17 +795,17 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
||||
|
||||
// Decimals [uint8]
|
||||
columns[i].decimals = data[pos]
|
||||
//pos++
|
||||
|
||||
// Default value [len coded binary]
|
||||
//if pos < len(data) {
|
||||
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
|
||||
//}
|
||||
}
|
||||
|
||||
// skip EOF packet if client does not support deprecateEOF
|
||||
if err := mc.skipEof(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// Read Packets as Field Packets until EOF-Packet or an Error appears
|
||||
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_row.html
|
||||
func (rows *textRows) readRow(dest []driver.Value) error {
|
||||
mc := rows.mc
|
||||
|
||||
@@ -804,9 +819,20 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
||||
}
|
||||
|
||||
// EOF Packet
|
||||
if data[0] == iEOF && len(data) == 5 {
|
||||
// server_status [2 bytes]
|
||||
rows.mc.status = readStatus(data[3:])
|
||||
// text row packets may starts with LengthEncodedString.
|
||||
// In such case, 0xFE can mean string larger than 0xffffff.
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le
|
||||
if data[0] == iEOF && len(data) <= 0xffffff {
|
||||
if mc.capabilities&clientDeprecateEOF == 0 {
|
||||
// Deprecated EOF packet
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_eof_packet.html
|
||||
mc.status = readStatus(data[3:])
|
||||
} else {
|
||||
// Ok Packet with an 0xFE header
|
||||
_, _, n := readLengthEncodedInteger(data[1:]) // affected_rows
|
||||
_, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id
|
||||
mc.status = readStatus(data[1+n+m:])
|
||||
}
|
||||
rows.rs.done = true
|
||||
if !rows.HasNextResultSet() {
|
||||
rows.mc = nil
|
||||
@@ -880,8 +906,34 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
|
||||
func (mc *mysqlConn) readUntilEOF() error {
|
||||
func (mc *mysqlConn) skipPackets(n int) error {
|
||||
for range n {
|
||||
if _, err := mc.readPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// skips EOF packet after n * ColumnDefinition packets when clientDeprecateEOF is not set
|
||||
func (mc *mysqlConn) skipEof() error {
|
||||
if mc.capabilities&clientDeprecateEOF == 0 {
|
||||
if _, err := mc.readPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) skipColumns(n int) error {
|
||||
if err := mc.skipPackets(n); err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.skipEof()
|
||||
}
|
||||
|
||||
// Reads Packets until EOF-Packet or an Error appears.
|
||||
func (mc *mysqlConn) skipRows() error {
|
||||
for {
|
||||
data, err := mc.readPacket()
|
||||
if err != nil {
|
||||
@@ -892,10 +944,20 @@ func (mc *mysqlConn) readUntilEOF() error {
|
||||
case iERR:
|
||||
return mc.handleErrorPacket(data)
|
||||
case iEOF:
|
||||
if len(data) == 5 {
|
||||
mc.status = readStatus(data[3:])
|
||||
// text row packets may starts with LengthEncodedString.
|
||||
// In such case, 0xFE can mean string larger than 0xffffff.
|
||||
if len(data) <= 0xffffff {
|
||||
if mc.capabilities&clientDeprecateEOF == 0 {
|
||||
// EOF packet
|
||||
mc.status = readStatus(data[3:])
|
||||
} else {
|
||||
// OK packet with an 0xFE header
|
||||
_, _, n := readLengthEncodedInteger(data[1:]) // affected_rows
|
||||
_, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id
|
||||
mc.status = readStatus(data[1+n+m:])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -905,7 +967,7 @@ func (mc *mysqlConn) readUntilEOF() error {
|
||||
******************************************************************************/
|
||||
|
||||
// Prepare Result Packets
|
||||
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response
|
||||
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
|
||||
data, err := stmt.mc.readPacket()
|
||||
if err == nil {
|
||||
@@ -932,7 +994,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_send_long_data.html
|
||||
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
|
||||
maxLen := stmt.mc.maxAllowedPacket - 1
|
||||
pktLen := maxLen
|
||||
@@ -979,7 +1041,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
|
||||
}
|
||||
|
||||
// Execute Prepared Statement
|
||||
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
|
||||
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
if len(args) != stmt.paramCount {
|
||||
return fmt.Errorf(
|
||||
@@ -993,10 +1055,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
mc := stmt.mc
|
||||
|
||||
// Determine threshold dynamically to avoid packet size shortage.
|
||||
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
|
||||
if longDataSize < 64 {
|
||||
longDataSize = 64
|
||||
}
|
||||
longDataSize := max(mc.maxAllowedPacket/(stmt.paramCount+1), 64)
|
||||
|
||||
// Reset packet-sequence
|
||||
mc.resetSequence()
|
||||
@@ -1185,17 +1244,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
// mc.affectedRows and mc.insertIds.
|
||||
func (mc *okHandler) discardResults() error {
|
||||
for mc.status&statusMoreResultsExists != 0 {
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
resLen, _, err := mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resLen > 0 {
|
||||
// columns
|
||||
if err := mc.conn().readUntilEOF(); err != nil {
|
||||
if err := mc.conn().skipColumns(resLen); err != nil {
|
||||
return err
|
||||
}
|
||||
// rows
|
||||
if err := mc.conn().readUntilEOF(); err != nil {
|
||||
if err := mc.conn().skipRows(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1203,7 +1262,7 @@ func (mc *okHandler) discardResults() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row
|
||||
func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||
data, err := rows.mc.readPacket()
|
||||
if err != nil {
|
||||
@@ -1212,9 +1271,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
||||
|
||||
// packet indicator [1 byte]
|
||||
if data[0] != iOK {
|
||||
// EOF Packet
|
||||
if data[0] == iEOF && len(data) == 5 {
|
||||
rows.mc.status = readStatus(data[3:])
|
||||
// EOF/OK Packet
|
||||
if data[0] == iEOF {
|
||||
if rows.mc.capabilities&clientDeprecateEOF == 0 {
|
||||
// EOF packet
|
||||
rows.mc.status = readStatus(data[3:])
|
||||
} else {
|
||||
// OK Packet with an 0xFE header
|
||||
_, _, n := readLengthEncodedInteger(data[1:])
|
||||
_, _, m := readLengthEncodedInteger(data[1+n:])
|
||||
rows.mc.status = readStatus(data[1+n+m:])
|
||||
}
|
||||
rows.rs.done = true
|
||||
if !rows.HasNextResultSet() {
|
||||
rows.mc = nil
|
||||
|
||||
+4
-2
@@ -8,6 +8,8 @@
|
||||
|
||||
package mysql
|
||||
|
||||
import "slices"
|
||||
|
||||
import "database/sql/driver"
|
||||
|
||||
// Result exposes data not available through *connection.Result.
|
||||
@@ -42,9 +44,9 @@ func (res *mysqlResult) RowsAffected() (int64, error) {
|
||||
}
|
||||
|
||||
func (res *mysqlResult) AllLastInsertIds() []int64 {
|
||||
return append([]int64{}, res.insertIds...) // defensive copy
|
||||
return slices.Clone(res.insertIds) // defensive copy
|
||||
}
|
||||
|
||||
func (res *mysqlResult) AllRowsAffected() []int64 {
|
||||
return append([]int64{}, res.affectedRows...) // defensive copy
|
||||
return slices.Clone(res.affectedRows) // defensive copy
|
||||
}
|
||||
|
||||
+5
-5
@@ -113,7 +113,7 @@ func (rows *mysqlRows) Close() (err error) {
|
||||
|
||||
// Remove unread packets from stream
|
||||
if !rows.rs.done {
|
||||
err = mc.readUntilEOF()
|
||||
err = mc.skipRows()
|
||||
}
|
||||
if err == nil {
|
||||
handleOk := mc.clearResult()
|
||||
@@ -143,7 +143,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
|
||||
|
||||
// Remove unread packets from stream
|
||||
if !rows.rs.done {
|
||||
if err := rows.mc.readUntilEOF(); err != nil {
|
||||
if err := rows.mc.skipRows(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows.rs.done = true
|
||||
@@ -156,7 +156,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
|
||||
rows.rs = resultSet{}
|
||||
// rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to
|
||||
// nextResultSet.
|
||||
resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
|
||||
resLen, _, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
// Clean up about multi-results flag
|
||||
rows.rs.done = true
|
||||
@@ -186,7 +186,7 @@ func (rows *binaryRows) NextResultSet() error {
|
||||
return err
|
||||
}
|
||||
|
||||
rows.rs.columns, err = rows.mc.readColumns(resLen)
|
||||
rows.rs.columns, err = rows.mc.readColumns(resLen, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -208,7 +208,7 @@ func (rows *textRows) NextResultSet() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
rows.rs.columns, err = rows.mc.readColumns(resLen)
|
||||
rows.rs.columns, err = rows.mc.readColumns(resLen, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
+26
-8
@@ -20,6 +20,7 @@ type mysqlStmt struct {
|
||||
mc *mysqlConn
|
||||
id uint32
|
||||
paramCount int
|
||||
columns []mysqlField
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Close() error {
|
||||
@@ -64,19 +65,26 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
handleOk := stmt.mc.clearResult()
|
||||
|
||||
// Read Result
|
||||
resLen, err := handleOk.readResultSetHeaderPacket()
|
||||
resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
if err = mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 {
|
||||
// we can not skip column metadata because next stmt.Query() may use it.
|
||||
if stmt.columns, err = mc.readColumns(resLen, stmt.columns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err = mc.skipColumns(resLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
if err = mc.skipRows(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -107,7 +115,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
|
||||
|
||||
// Read Result
|
||||
handleOk := stmt.mc.clearResult()
|
||||
resLen, err := handleOk.readResultSetHeaderPacket()
|
||||
resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -116,7 +124,17 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
|
||||
|
||||
if resLen > 0 {
|
||||
rows.mc = mc
|
||||
rows.rs.columns, err = mc.readColumns(resLen)
|
||||
if metadataFollows {
|
||||
if rows.rs.columns, err = mc.readColumns(resLen, stmt.columns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt.columns = rows.rs.columns
|
||||
} else {
|
||||
if err = mc.skipEof(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows.rs.columns = stmt.columns
|
||||
}
|
||||
} else {
|
||||
rows.rs.done = true
|
||||
|
||||
@@ -131,7 +149,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
|
||||
return rows, err
|
||||
}
|
||||
|
||||
var jsonType = reflect.TypeOf(json.RawMessage{})
|
||||
var jsonType = reflect.TypeFor[json.RawMessage]()
|
||||
|
||||
type converter struct{}
|
||||
|
||||
@@ -193,7 +211,7 @@ func (c converter) ConvertValue(v any) (driver.Value, error) {
|
||||
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
|
||||
}
|
||||
|
||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
var valuerReflectType = reflect.TypeFor[driver.Valuer]()
|
||||
|
||||
// callValuerValue returns vr.Value(), with one exception:
|
||||
// If vr.Value is an auto-generated method on a pointer type and the
|
||||
|
||||
+65
-90
@@ -182,7 +182,7 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
|
||||
|
||||
func parseByteYear(b []byte) (int, error) {
|
||||
year, n := 0, 1000
|
||||
for i := 0; i < 4; i++ {
|
||||
for i := range 4 {
|
||||
v, err := bToi(b[i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -207,7 +207,7 @@ func parseByte2Digits(b1, b2 byte) (int, error) {
|
||||
|
||||
func parseByteNanoSec(b []byte) (int, error) {
|
||||
ns, digit := 0, 100000 // max is 6-digits
|
||||
for i := 0; i < len(b); i++ {
|
||||
for i := range b {
|
||||
v, err := bToi(b[i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -625,108 +625,80 @@ func reserveBuffer(buf []byte, appendSize int) []byte {
|
||||
return buf[:newSize]
|
||||
}
|
||||
|
||||
// escapeBytesBackslash escapes []byte with backslashes (\)
|
||||
// This escapes the contents of a string (provided as []byte) by adding backslashes before special
|
||||
// characters, and turning others into specific escape sequences, such as
|
||||
// turning newlines into \n and null bytes into \0.
|
||||
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
|
||||
func escapeBytesBackslash(buf, v []byte) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
// Lookup table for backslash escapes (used for both string and bytes)
|
||||
var backslashEscapeTable [256]byte
|
||||
|
||||
for _, c := range v {
|
||||
switch c {
|
||||
case '\x00':
|
||||
buf[pos+1] = '0'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\n':
|
||||
buf[pos+1] = 'n'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\r':
|
||||
buf[pos+1] = 'r'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\x1a':
|
||||
buf[pos+1] = 'Z'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\'':
|
||||
buf[pos+1] = '\''
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '"':
|
||||
buf[pos+1] = '"'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\\':
|
||||
buf[pos+1] = '\\'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
default:
|
||||
buf[pos] = c
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
return buf[:pos]
|
||||
func init() {
|
||||
backslashEscapeTable['\x00'] = '0'
|
||||
backslashEscapeTable['\n'] = 'n'
|
||||
backslashEscapeTable['\r'] = 'r'
|
||||
backslashEscapeTable['\x1a'] = 'Z'
|
||||
backslashEscapeTable['\''] = '\''
|
||||
backslashEscapeTable['"'] = '"'
|
||||
backslashEscapeTable['\\'] = '\\'
|
||||
}
|
||||
|
||||
// escapeStringBackslash is similar to escapeBytesBackslash but for string.
|
||||
func escapeStringBackslash(buf []byte, v string) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
|
||||
buf = reserveBuffer(buf, len(v)*2+2)
|
||||
buf[pos] = '\''
|
||||
pos++
|
||||
for i := 0; i < len(v); i++ {
|
||||
c := v[i]
|
||||
switch c {
|
||||
case '\x00':
|
||||
buf[pos+1] = '0'
|
||||
if esc := backslashEscapeTable[c]; esc != 0 {
|
||||
buf[pos+1] = esc
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\n':
|
||||
buf[pos+1] = 'n'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\r':
|
||||
buf[pos+1] = 'r'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\x1a':
|
||||
buf[pos+1] = 'Z'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\'':
|
||||
buf[pos+1] = '\''
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '"':
|
||||
buf[pos+1] = '"'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
case '\\':
|
||||
buf[pos+1] = '\\'
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
default:
|
||||
} else {
|
||||
buf[pos] = c
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
buf[pos] = '\''
|
||||
pos++
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
|
||||
// This escapes the contents of a string by doubling up any apostrophes that
|
||||
// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
|
||||
// effect on the server.
|
||||
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
|
||||
func escapeBytesQuotes(buf, v []byte) []byte {
|
||||
// escapeBytesBackslash appends _binary'...' or '...' with backslash escaping for bytes.
|
||||
func escapeBytesBackslash(buf, v []byte, binary bool) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
if binary {
|
||||
buf = reserveBuffer(buf, len(v)*2+9)
|
||||
copy(buf[pos:], []byte("_binary'"))
|
||||
pos += 8
|
||||
} else {
|
||||
buf = reserveBuffer(buf, len(v)*2+2)
|
||||
buf[pos] = '\''
|
||||
pos++
|
||||
}
|
||||
for _, c := range v {
|
||||
if esc := backslashEscapeTable[c]; esc != 0 {
|
||||
buf[pos+1] = esc
|
||||
buf[pos] = '\\'
|
||||
pos += 2
|
||||
} else {
|
||||
buf[pos] = c
|
||||
pos++
|
||||
}
|
||||
}
|
||||
buf[pos] = '\''
|
||||
pos++
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
// escapeBytesQuotes appends _binary'...' or '...' with single-quote escaping for bytes.
|
||||
func escapeBytesQuotes(buf, v []byte, binary bool) []byte {
|
||||
pos := len(buf)
|
||||
if binary {
|
||||
buf = reserveBuffer(buf, len(v)*2+9)
|
||||
copy(buf[pos:], []byte("_binary'"))
|
||||
pos += 8
|
||||
} else {
|
||||
buf = reserveBuffer(buf, len(v)*2+2)
|
||||
buf[pos] = '\''
|
||||
pos++
|
||||
}
|
||||
for _, c := range v {
|
||||
if c == '\'' {
|
||||
buf[pos+1] = '\''
|
||||
@@ -737,16 +709,18 @@ func escapeBytesQuotes(buf, v []byte) []byte {
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
buf[pos] = '\''
|
||||
pos++
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
// escapeStringQuotes is similar to escapeBytesQuotes but for string.
|
||||
func escapeStringQuotes(buf []byte, v string) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
|
||||
for i := 0; i < len(v); i++ {
|
||||
buf = reserveBuffer(buf, len(v)*2+2)
|
||||
buf[pos] = '\''
|
||||
pos++
|
||||
for i := range len(v) {
|
||||
c := v[i]
|
||||
if c == '\'' {
|
||||
buf[pos+1] = '\''
|
||||
@@ -757,7 +731,8 @@ func escapeStringQuotes(buf []byte, v string) []byte {
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
buf[pos] = '\''
|
||||
pos++
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
|
||||
+33
-15
@@ -696,7 +696,7 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
|
||||
// If we are here we have an auth callout defined and we have failed auth so far
|
||||
// so we will callout to our auth backend for processing.
|
||||
if !skip {
|
||||
authorized, reason = s.processClientOrLeafCallout(c, opts, proxyRequired, trustedProxy)
|
||||
authorized, reason = s.processClientOrLeafCallout(c, opts, proxyRequired, trustedProxy, ujwt)
|
||||
}
|
||||
// Check if we are authorized and in the auth callout account, and if so add in deny publish permissions for the auth subject.
|
||||
if authorized {
|
||||
@@ -797,26 +797,42 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
|
||||
token = opts.Authorization
|
||||
}
|
||||
|
||||
// Check if we have trustedKeys defined in the server. If so we require a user jwt.
|
||||
if s.trustedKeys != nil {
|
||||
ujwt = c.opts.JWT
|
||||
if ujwt == _EMPTY_ && c.isMqtt() {
|
||||
// For MQTT, we pass the password as the JWT too, but do so here so it's not
|
||||
// publicly exposed in the client options if it isn't a JWT.
|
||||
// MQTT can carry JWTs in the password field. Reconstruct it here for auth
|
||||
// processing and auth callout, but do not populate c.opts.JWT yet or it would
|
||||
// be exposed through monitoring and advisory paths even when the password is
|
||||
// not actually a JWT.
|
||||
if ujwt == _EMPTY_ && c.isMqtt() && c.opts.JWT == _EMPTY_ {
|
||||
// Don't set juc here, leave that to the next s.trustedKeys != nil block,
|
||||
// so that we don't try to trust a JWT when we aren't in operator mode. We
|
||||
// will allow it to be passed through auth callout though.
|
||||
if _, err := jwt.DecodeUserClaims(c.opts.Password); err == nil {
|
||||
ujwt = c.opts.Password
|
||||
}
|
||||
if ujwt == _EMPTY_ && opts.DefaultSentinel != _EMPTY_ {
|
||||
c.opts.JWT = opts.DefaultSentinel
|
||||
ujwt = c.opts.JWT
|
||||
}
|
||||
|
||||
// Check if we have trustedKeys defined in the server. If so we require a user jwt.
|
||||
if s.trustedKeys != nil {
|
||||
if ujwt == _EMPTY_ {
|
||||
// Need to be sure that it's a NATS JWT, otherwise we will not correctly
|
||||
// attempt the default sentinel below.
|
||||
if _, err = jwt.DecodeUserClaims(c.opts.JWT); err == nil {
|
||||
ujwt = c.opts.JWT
|
||||
}
|
||||
}
|
||||
if ujwt == _EMPTY_ {
|
||||
// Didn't fall through with a valid NATS JWT, so try the default sentinel
|
||||
// if configured.
|
||||
if opts.DefaultSentinel != _EMPTY_ {
|
||||
c.opts.JWT = opts.DefaultSentinel
|
||||
ujwt = c.opts.JWT
|
||||
}
|
||||
}
|
||||
if ujwt == _EMPTY_ {
|
||||
s.mu.Unlock()
|
||||
c.Debugf("Authentication requires a user JWT")
|
||||
return false
|
||||
}
|
||||
// So we have a valid user jwt here.
|
||||
juc, err = jwt.DecodeUserClaims(ujwt)
|
||||
if err != nil {
|
||||
if juc, err = jwt.DecodeUserClaims(ujwt); err != nil {
|
||||
s.mu.Unlock()
|
||||
c.Debugf("User JWT not valid: %v", err)
|
||||
return false
|
||||
@@ -1015,8 +1031,10 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
|
||||
c.Debugf("Connection type not allowed")
|
||||
return false
|
||||
}
|
||||
// skip validation of nonce when presented with a bearer token
|
||||
// FIXME: if BearerToken is only for WSS, need check for server with that port enabled
|
||||
// Skip validation of nonce when presented with a bearer token.
|
||||
// While support for bearer tokens was added for WebSockets, there is no
|
||||
// security benefit in restricting their use to that client protocol: the
|
||||
// client can just go use the other protocol.
|
||||
if !juc.BearerToken {
|
||||
// Verify the signature against the nonce.
|
||||
if c.opts.Sig == _EMPTY_ {
|
||||
|
||||
+10
-4
@@ -41,7 +41,7 @@ func titleCase(m string) string {
|
||||
}
|
||||
|
||||
// Process a callout on this client's behalf.
|
||||
func (s *Server) processClientOrLeafCallout(c *client, opts *Options, proxyRequired, trustedProxy bool) (authorized bool, errStr string) {
|
||||
func (s *Server) processClientOrLeafCallout(c *client, opts *Options, proxyRequired, trustedProxy bool, ujwt string) (authorized bool, errStr string) {
|
||||
isOperatorMode := len(opts.TrustedKeys) > 0
|
||||
|
||||
// this is the account the user connected in, or the one running the callout
|
||||
@@ -374,7 +374,7 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options, proxyRequi
|
||||
// Grab client info for the request.
|
||||
c.mu.Lock()
|
||||
c.fillClientInfo(&claim.ClientInformation)
|
||||
c.fillConnectOpts(&claim.ConnectOptions)
|
||||
c.fillConnectOpts(&claim.ConnectOptions, ujwt)
|
||||
// If we have a sig in the client opts, fill in nonce.
|
||||
if claim.ConnectOptions.SignedNonce != _EMPTY_ {
|
||||
claim.ClientInformation.Nonce = string(c.nonce)
|
||||
@@ -474,16 +474,22 @@ func (c *client) fillClientInfo(ci *jwt.ClientInformation) {
|
||||
|
||||
// Fill in client options.
|
||||
// Lock should be held.
|
||||
func (c *client) fillConnectOpts(opts *jwt.ConnectOptions) {
|
||||
func (c *client) fillConnectOpts(opts *jwt.ConnectOptions, ujwt string) {
|
||||
if c == nil || (c.kind != CLIENT && c.kind != LEAF && c.kind != JETSTREAM && c.kind != ACCOUNT) {
|
||||
return
|
||||
}
|
||||
|
||||
o := c.opts
|
||||
if ujwt == _EMPTY_ {
|
||||
// The caller may supply a reconstructed JWT that should be sent to auth
|
||||
// callout without storing it in c.opts.JWT. If not, fall back to the client
|
||||
// option as before.
|
||||
ujwt = o.JWT
|
||||
}
|
||||
|
||||
// Do it this way to fail to compile if fields are added to jwt.ClientInformation.
|
||||
*opts = jwt.ConnectOptions{
|
||||
JWT: o.JWT,
|
||||
JWT: ujwt,
|
||||
Nkey: o.Nkey,
|
||||
SignedNonce: o.Sig,
|
||||
Token: o.Token,
|
||||
|
||||
+2
-2
@@ -239,7 +239,7 @@ func (ss SequenceSet) EncodeLen() int {
|
||||
return minLen + (ss.Nodes() * ((numBuckets+1)*8 + 2))
|
||||
}
|
||||
|
||||
func (ss SequenceSet) Encode(buf []byte) ([]byte, error) {
|
||||
func (ss SequenceSet) Encode(buf []byte) []byte {
|
||||
nn, encLen := ss.Nodes(), ss.EncodeLen()
|
||||
|
||||
if cap(buf) < encLen {
|
||||
@@ -268,7 +268,7 @@ func (ss SequenceSet) Encode(buf []byte) ([]byte, error) {
|
||||
le.PutUint16(buf[i:], uint16(n.h))
|
||||
i += 2
|
||||
})
|
||||
return buf[:i], nil
|
||||
return buf[:i]
|
||||
}
|
||||
|
||||
// ErrBadEncoding is returned when we can not decode properly.
|
||||
|
||||
+191
-122
@@ -1061,18 +1061,19 @@ func (c *client) setPermissions(perms *Permissions) {
|
||||
return
|
||||
}
|
||||
c.perms = &permissions{}
|
||||
slcache := c.srv != nil && !c.srv.getOpts().NoSublistCache
|
||||
|
||||
// Loop over publish permissions
|
||||
if perms.Publish != nil {
|
||||
if perms.Publish.Allow != nil {
|
||||
c.perms.pub.allow = NewSublistWithCache()
|
||||
c.perms.pub.allow = NewSublist(slcache)
|
||||
}
|
||||
for _, pubSubject := range perms.Publish.Allow {
|
||||
sub := &subscription{subject: []byte(pubSubject)}
|
||||
c.perms.pub.allow.Insert(sub)
|
||||
}
|
||||
if len(perms.Publish.Deny) > 0 {
|
||||
c.perms.pub.deny = NewSublistWithCache()
|
||||
c.perms.pub.deny = NewSublist(slcache)
|
||||
}
|
||||
for _, pubSubject := range perms.Publish.Deny {
|
||||
sub := &subscription{subject: []byte(pubSubject)}
|
||||
@@ -1091,7 +1092,7 @@ func (c *client) setPermissions(perms *Permissions) {
|
||||
if perms.Subscribe != nil {
|
||||
var err error
|
||||
if len(perms.Subscribe.Allow) > 0 {
|
||||
c.perms.sub.allow = NewSublistWithCache()
|
||||
c.perms.sub.allow = NewSublist(slcache)
|
||||
}
|
||||
for _, subSubject := range perms.Subscribe.Allow {
|
||||
sub := &subscription{}
|
||||
@@ -1103,7 +1104,7 @@ func (c *client) setPermissions(perms *Permissions) {
|
||||
c.perms.sub.allow.Insert(sub)
|
||||
}
|
||||
if len(perms.Subscribe.Deny) > 0 {
|
||||
c.perms.sub.deny = NewSublistWithCache()
|
||||
c.perms.sub.deny = NewSublist(slcache)
|
||||
// Also hold onto this array for later.
|
||||
c.darray = perms.Subscribe.Deny
|
||||
}
|
||||
@@ -1200,6 +1201,7 @@ func (c *client) mergeDenyPermissions(what denyType, denyPubs []string) {
|
||||
if c.perms == nil {
|
||||
c.perms = &permissions{}
|
||||
}
|
||||
slcache := c.srv != nil && !c.srv.getOpts().NoSublistCache
|
||||
var perms []*perm
|
||||
switch what {
|
||||
case pub:
|
||||
@@ -1211,7 +1213,7 @@ func (c *client) mergeDenyPermissions(what denyType, denyPubs []string) {
|
||||
}
|
||||
for _, p := range perms {
|
||||
if p.deny == nil {
|
||||
p.deny = NewSublistWithCache()
|
||||
p.deny = NewSublist(slcache)
|
||||
}
|
||||
FOR_DENY:
|
||||
for _, subj := range denyPubs {
|
||||
@@ -2254,12 +2256,20 @@ func (c *client) processConnect(arg []byte) error {
|
||||
// least ClientProtoInfo, we need to increment the following counter.
|
||||
// This is decremented when client is removed from the server's
|
||||
// clients map.
|
||||
if kind == CLIENT && proto >= ClientProtoInfo {
|
||||
if kind == CLIENT && proto >= ClientProtoInfo && firstConnect {
|
||||
srv.mu.Lock()
|
||||
srv.cproto++
|
||||
srv.mu.Unlock()
|
||||
}
|
||||
|
||||
// A second CONNECT may move the client into a different account via
|
||||
// checkAuthentication. Drop any previously-registered subscriptions
|
||||
// from the current account first so they don't leak in that account's
|
||||
// sublist after the client switches.
|
||||
if !firstConnect {
|
||||
c.clearAccountSubs(false)
|
||||
}
|
||||
|
||||
// Check for Auth
|
||||
if ok := srv.checkAuthentication(c); !ok {
|
||||
// We may fail here because we reached max limits on an account.
|
||||
@@ -3273,19 +3283,20 @@ func (c *client) canSubscribe(subject string, optQueue ...string) bool {
|
||||
r := c.perms.sub.deny.Match(subject)
|
||||
allowed = len(r.psubs) == 0
|
||||
|
||||
if queue != _EMPTY_ && len(r.qsubs) > 0 {
|
||||
if allowed && queue != _EMPTY_ && len(r.qsubs) > 0 {
|
||||
// If the queue appears in the deny list, then DO NOT allow.
|
||||
allowed = !queueMatches(queue, r.qsubs)
|
||||
}
|
||||
|
||||
// We use the actual subscription to signal us to spin up the deny mperms
|
||||
// and cache. We check if the subject is a wildcard that contains any of
|
||||
// and cache. We check if the subject is a wildcard that intersects any of
|
||||
// the deny clauses.
|
||||
// FIXME(dlc) - We could be smarter and track when these go away and remove.
|
||||
if allowed && c.mperms == nil && subjectHasWildcard(subject) {
|
||||
// Whip through the deny array and check if this wildcard subject is within scope.
|
||||
// Whip through the deny array and check if this wildcard subject can
|
||||
// overlap with any denied deliveries.
|
||||
for _, sub := range c.darray {
|
||||
if subjectIsSubsetMatch(sub, subject) {
|
||||
if SubjectsCollide(sub, subject) {
|
||||
c.loadMsgDenyFilter()
|
||||
break
|
||||
}
|
||||
@@ -3658,14 +3669,7 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su
|
||||
|
||||
// Check if we are a leafnode and have perms to check.
|
||||
if client.kind == LEAF && client.perms != nil {
|
||||
var subjectToCheck []byte
|
||||
if subject[0] == '_' && bytes.HasPrefix(subject, []byte(gwReplyPrefix)) {
|
||||
subjectToCheck = subject[gwSubjectOffset:]
|
||||
} else if subject[0] == '$' && bytes.HasPrefix(subject, []byte(oldGWReplyPrefix)) {
|
||||
subjectToCheck = subject[oldGWReplyStart:]
|
||||
} else {
|
||||
subjectToCheck = subject
|
||||
}
|
||||
subjectToCheck, _ := getGWRoutedSubjectOrSelf(subject)
|
||||
if !client.pubAllowedFullCheck(string(subjectToCheck), true, true) {
|
||||
mt.addEgressEvent(client, sub, errMsgTracePubViolation)
|
||||
client.mu.Unlock()
|
||||
@@ -4068,7 +4072,7 @@ func (c *client) allowedMsgTraceDest(hdr []byte, hasLock bool) (string, bool) {
|
||||
return _EMPTY_, true
|
||||
}
|
||||
td := sliceHeader(MsgTraceDest, hdr)
|
||||
if len(td) == 0 {
|
||||
if len(td) == 0 || bytes.Equal(td, traceDestDisabledAsBytes) {
|
||||
return _EMPTY_, true
|
||||
}
|
||||
dest := bytesToString(td)
|
||||
@@ -4131,17 +4135,7 @@ func (c *client) pubAllowedFullCheck(subject string, fullCheck, hasLock bool) bo
|
||||
if !hasLock {
|
||||
c.mu.Lock()
|
||||
}
|
||||
if resp := c.replies[subject]; resp != nil {
|
||||
resp.n++
|
||||
// Check if we have sent too many responses.
|
||||
if c.perms.resp.MaxMsgs > 0 && resp.n > c.perms.resp.MaxMsgs {
|
||||
delete(c.replies, subject)
|
||||
} else if c.perms.resp.Expires > 0 && time.Since(resp.t) > c.perms.resp.Expires {
|
||||
delete(c.replies, subject)
|
||||
} else {
|
||||
allowed = true
|
||||
}
|
||||
}
|
||||
allowed = c.responseAllowed(subject)
|
||||
if !hasLock {
|
||||
c.mu.Unlock()
|
||||
}
|
||||
@@ -4155,6 +4149,25 @@ func (c *client) pubAllowedFullCheck(subject string, fullCheck, hasLock bool) bo
|
||||
return allowed
|
||||
}
|
||||
|
||||
// Returns true if this subject matches a tracked dynamic reply permission.
|
||||
// Lock must be held.
|
||||
func (c *client) responseAllowed(subject string) bool {
|
||||
if c.perms == nil || c.perms.resp == nil {
|
||||
return false
|
||||
}
|
||||
if resp := c.replies[subject]; resp != nil {
|
||||
resp.n++
|
||||
if c.perms.resp.MaxMsgs > 0 && resp.n > c.perms.resp.MaxMsgs {
|
||||
delete(c.replies, subject)
|
||||
} else if c.perms.resp.Expires > 0 && time.Since(resp.t) > c.perms.resp.Expires {
|
||||
delete(c.replies, subject)
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Test whether a reply subject is a service import reply.
|
||||
func isServiceReply(reply []byte) bool {
|
||||
// This function is inlined and checking this way is actually faster
|
||||
@@ -4162,16 +4175,51 @@ func isServiceReply(reply []byte) bool {
|
||||
return len(reply) > 3 && bytesToString(reply[:4]) == replyPrefix
|
||||
}
|
||||
|
||||
// Test whether a subject is a JetStream ACK.
|
||||
func isJSAckSubject(subject []byte) bool {
|
||||
return len(subject) > jsAckPreLen && bytesToString(subject[:jsAckPreLen]) == jsAckPre
|
||||
}
|
||||
|
||||
// jsAckDeliverIdx returns the byte offset of the `@` separator in an encoded
|
||||
// `$JS.ACK....@<deliver>` reply, or -1 if reply is not in that form. Stream,
|
||||
// consumer, and subject tokens may legally contain `@`, so we accept only the
|
||||
// first `@` that follows the eight dots of the JS ACK token:
|
||||
//
|
||||
// $JS.ACK.<stream>.<consumer>.<delivered>.<sseq>.<cseq>.<tm>.<pending>@<deliver>
|
||||
func jsAckDeliverIdx(reply []byte) int {
|
||||
if !isJSAckSubject(reply) {
|
||||
return -1
|
||||
}
|
||||
dots := 0
|
||||
for i, b := range reply {
|
||||
switch b {
|
||||
case '.':
|
||||
dots++
|
||||
case '@':
|
||||
if dots >= 8 {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// replyHasJSAckSuffix reports whether reply is already in `$JS.ACK....@<deliver>`
|
||||
// form, so callers don't double-append the suffix on a re-entrant pass
|
||||
// (service-import or chained JS push).
|
||||
func replyHasJSAckSuffix(reply []byte) bool {
|
||||
return jsAckDeliverIdx(reply) != -1
|
||||
}
|
||||
|
||||
// Test whether a reply subject is a service import or a gateway routed reply.
|
||||
func isReservedReply(reply []byte) bool {
|
||||
if isServiceReply(reply) {
|
||||
return true
|
||||
}
|
||||
rLen := len(reply)
|
||||
// Faster to check with string([:]) than byte-by-byte
|
||||
if rLen > jsAckPreLen && bytesToString(reply[:jsAckPreLen]) == jsAckPre {
|
||||
if isJSAckSubject(reply) {
|
||||
return true
|
||||
} else if rLen > gwReplyPrefixLen && bytesToString(reply[:gwReplyPrefixLen]) == gwReplyPrefix {
|
||||
} else if len(reply) > gwReplyPrefixLen && bytesToString(reply[:gwReplyPrefixLen]) == gwReplyPrefix {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -4370,7 +4418,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
|
||||
// Now deal with gateways
|
||||
if c.srv.gateway.enabled {
|
||||
reply := c.pa.reply
|
||||
if len(c.pa.deliver) > 0 && c.kind == JETSTREAM && len(c.pa.reply) > 0 {
|
||||
if len(c.pa.deliver) > 0 && c.kind == JETSTREAM && len(reply) > 0 && !replyHasJSAckSuffix(reply) {
|
||||
reply = append(reply, '@')
|
||||
reply = append(reply, c.pa.deliver...)
|
||||
}
|
||||
@@ -4418,7 +4466,7 @@ func (c *client) handleGWReplyMap(msg []byte) bool {
|
||||
}
|
||||
if c.srv.gateway.enabled {
|
||||
reply := c.pa.reply
|
||||
if len(c.pa.deliver) > 0 && c.kind == JETSTREAM && len(c.pa.reply) > 0 {
|
||||
if len(c.pa.deliver) > 0 && c.kind == JETSTREAM && len(reply) > 0 && !replyHasJSAckSuffix(reply) {
|
||||
reply = append(reply, '@')
|
||||
reply = append(reply, c.pa.deliver...)
|
||||
}
|
||||
@@ -4531,7 +4579,8 @@ func (c *client) setHeader(key, value string, msg []byte) []byte {
|
||||
// Write original header if present.
|
||||
if c.pa.hdr > LEN_CR_LF {
|
||||
omi = c.pa.hdr
|
||||
hdr := removeHeaderIfPresent(msg[:c.pa.hdr-LEN_CR_LF], key)
|
||||
// Need to copy since we're removing the header in place.
|
||||
hdr := removeHeaderIfPresent(copyBytes(msg[:c.pa.hdr-LEN_CR_LF]), key)
|
||||
if len(hdr) == 0 {
|
||||
bb.WriteString(hdrLine)
|
||||
} else {
|
||||
@@ -4825,6 +4874,12 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
|
||||
// but the local server must replace it with the identity of the
|
||||
// authenticated leaf connection instead of trusting forwarded values.
|
||||
ci = c.getClientInfo(share)
|
||||
if hadPrevSi && cis != nil && cis.Reply != _EMPTY_ {
|
||||
ci.Reply = cis.Reply
|
||||
} else if bytes.HasSuffix(c.pa.reply, []byte(FastBatchSuffix)) {
|
||||
// Fast batch requires knowledge of the original reply subject.
|
||||
ci.Reply = bytesToString(c.pa.reply)
|
||||
}
|
||||
if hadPrevSi {
|
||||
ci.Service = acc.Name
|
||||
if !share && (si.share || isSysImport) {
|
||||
@@ -4843,6 +4898,10 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
|
||||
}
|
||||
} else if c.kind != LEAF || c.pa.hdr < 0 || len(sliceHeader(ClientInfoHdr, msg[:c.pa.hdr])) == 0 {
|
||||
ci = c.getClientInfo(share)
|
||||
// Fast batch requires knowledge of the original reply subject.
|
||||
if bytes.HasSuffix(c.pa.reply, []byte(FastBatchSuffix)) {
|
||||
ci.Reply = bytesToString(c.pa.reply)
|
||||
}
|
||||
// If we did not share but the imports destination is the system account add in the server and cluster info.
|
||||
if !share && isSysImport {
|
||||
c.addServerAndClusterInfo(ci)
|
||||
@@ -4902,8 +4961,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
|
||||
// We also need to disable the message trace headers so that
|
||||
// if the message is routed, it does not initialize tracing in the
|
||||
// remote.
|
||||
positions := disableTraceHeaders(c, msg)
|
||||
defer enableTraceHeaders(msg, positions)
|
||||
msg = c.setHeader(MsgTraceDest, MsgTraceDestDisabled, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5037,21 +5095,9 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver,
|
||||
// Check for JetStream encoded reply subjects.
|
||||
// For now these will only be on $JS.ACK prefixed reply subjects.
|
||||
var remapped bool
|
||||
if len(creply) > 0 && c.kind != CLIENT && !isInternalClient(c.kind) && bytes.HasPrefix(creply, []byte(jsAckPre)) {
|
||||
if len(creply) > 0 && c.kind != CLIENT && !isInternalClient(c.kind) {
|
||||
// We need to rewrite the subject and the reply.
|
||||
// But, we must be careful that the stream name, consumer name, and subject can contain '@' characters.
|
||||
// JS ACK contains at least 8 dots, find the first @ after this prefix.
|
||||
// - $JS.ACK.<stream>.<consumer>.<delivered>.<sseq>.<cseq>.<tm>.<pending>
|
||||
counter := 0
|
||||
li := bytes.IndexFunc(creply, func(rn rune) bool {
|
||||
if rn == '.' {
|
||||
counter++
|
||||
} else if rn == '@' {
|
||||
return counter >= 8
|
||||
}
|
||||
return false
|
||||
})
|
||||
if li != -1 && li < len(creply)-1 {
|
||||
if li := jsAckDeliverIdx(creply); li != -1 && li < len(creply)-1 {
|
||||
remapped = true
|
||||
subj, creply = creply[li+1:], creply[:li]
|
||||
}
|
||||
@@ -5471,7 +5517,7 @@ sendToRoutesOrLeafs:
|
||||
// at the end of the reply subject if it exists. But only if this wasn't
|
||||
// already performed, otherwise we'd end up with a duplicate '@' suffix
|
||||
// resulting in a protocol error.
|
||||
if len(deliver) > 0 && len(reply) > 0 && !remapped {
|
||||
if len(deliver) > 0 && len(reply) > 0 && !remapped && !replyHasJSAckSuffix(reply) {
|
||||
reply = append(reply, '@')
|
||||
reply = append(reply, deliver...)
|
||||
}
|
||||
@@ -5754,11 +5800,12 @@ func (c *client) clearAuthTimer() bool {
|
||||
return stopped
|
||||
}
|
||||
|
||||
// We may reuse atmr for expiring user jwts,
|
||||
// so check connectReceived.
|
||||
// Track whether the parser should still enforce pre-CONNECT rules.
|
||||
// This is handshake state, not timer state, since some handshakes
|
||||
// use a different timer while still expecting CONNECT.
|
||||
// Lock assume held on entry.
|
||||
func (c *client) awaitingAuth() bool {
|
||||
return !c.flags.isSet(connectReceived) && c.atmr != nil
|
||||
return c.flags.isSet(expectConnect) && !c.flags.isSet(connectReceived)
|
||||
}
|
||||
|
||||
// This will set the atmr for the JWT expiration time.
|
||||
@@ -5987,37 +6034,12 @@ func (c *client) closeConnection(reason ClosedState) {
|
||||
srv = c.srv
|
||||
noReconnect = c.flags.isSet(noReconnect)
|
||||
acc = c.acc
|
||||
spoke bool
|
||||
)
|
||||
|
||||
// Snapshot for use if we are a client connection.
|
||||
// FIXME(dlc) - we can just stub in a new one for client
|
||||
// and reference existing one.
|
||||
var subs []*subscription
|
||||
if kind == CLIENT || kind == LEAF || kind == JETSTREAM {
|
||||
var _subs [32]*subscription
|
||||
subs = _subs[:0]
|
||||
// Do not set c.subs to nil or delete the sub from c.subs here because
|
||||
// it will be needed in saveClosedClient (which has been started as a
|
||||
// go routine in markConnAsClosed). Cleanup will be done there.
|
||||
for _, sub := range c.subs {
|
||||
// Auto-unsubscribe subscriptions must be unsubscribed forcibly.
|
||||
sub.max = 0
|
||||
sub.close()
|
||||
subs = append(subs, sub)
|
||||
}
|
||||
spoke = c.isSpokeLeafNode()
|
||||
}
|
||||
|
||||
c.mu.Unlock()
|
||||
|
||||
// Remove client's or leaf node or jetstream subscriptions.
|
||||
if acc != nil && (kind == CLIENT || kind == LEAF || kind == JETSTREAM) {
|
||||
acc.sl.RemoveBatch(subs)
|
||||
} else if kind == ROUTER {
|
||||
if kind == ROUTER {
|
||||
c.removeRemoteSubs()
|
||||
}
|
||||
|
||||
if srv != nil {
|
||||
// Unregister
|
||||
srv.removeClient(c)
|
||||
@@ -6025,45 +6047,11 @@ func (c *client) closeConnection(reason ClosedState) {
|
||||
if acc != nil {
|
||||
// Update remote subscriptions.
|
||||
if kind == CLIENT || kind == LEAF || kind == JETSTREAM {
|
||||
qsubs := map[string]*qsub{}
|
||||
for _, sub := range subs {
|
||||
// Call unsubscribe here to cleanup shadow subscriptions and such.
|
||||
c.unsubscribe(acc, sub, true, false)
|
||||
// Update route as normal for a normal subscriber.
|
||||
if sub.queue == nil {
|
||||
if !spoke {
|
||||
srv.updateRouteSubscriptionMap(acc, sub, -1)
|
||||
if srv.gateway.enabled {
|
||||
srv.gatewayUpdateSubInterest(acc.Name, sub, -1)
|
||||
}
|
||||
}
|
||||
acc.updateLeafNodes(sub, -1)
|
||||
} else {
|
||||
// We handle queue subscribers special in case we
|
||||
// have a bunch we can just send one update to the
|
||||
// connected routes.
|
||||
num := int32(1)
|
||||
if kind == LEAF {
|
||||
num = sub.qw
|
||||
}
|
||||
key := keyFromSub(sub)
|
||||
if esub, ok := qsubs[key]; ok {
|
||||
esub.n += num
|
||||
} else {
|
||||
qsubs[key] = &qsub{sub, num}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Process any qsubs here.
|
||||
for _, esub := range qsubs {
|
||||
if !spoke {
|
||||
srv.updateRouteSubscriptionMap(acc, esub.sub, -(esub.n))
|
||||
if srv.gateway.enabled {
|
||||
srv.gatewayUpdateSubInterest(acc.Name, esub.sub, -(esub.n))
|
||||
}
|
||||
}
|
||||
acc.updateLeafNodes(esub.sub, -(esub.n))
|
||||
}
|
||||
// Remove client's subscriptions from the account and unregister
|
||||
// client from that account. Keep c.subs populated because
|
||||
// saveClosedClient (started as a goroutine in markConnAsClosed)
|
||||
// still needs to read it.
|
||||
c.clearAccountSubs(true)
|
||||
}
|
||||
// Always remove from the account, otherwise we can leak clients.
|
||||
// Note that SYSTEM and ACCOUNT types from above cleanup their own subs.
|
||||
@@ -6090,6 +6078,87 @@ func (c *client) closeConnection(reason ClosedState) {
|
||||
c.reconnect()
|
||||
}
|
||||
|
||||
// clearAccountSubs removes the client's subscriptions from its current account
|
||||
// and unregisters it from that account. If close is true, c.subs is left
|
||||
// populated for saveClosedClient; otherwise c.subs is cleared and c.acc
|
||||
// registered back to the global account.
|
||||
// Client lock MUST NOT be held on entry.
|
||||
func (c *client) clearAccountSubs(close bool) {
|
||||
c.mu.Lock()
|
||||
kind := c.kind
|
||||
srv := c.srv
|
||||
acc := c.acc
|
||||
if acc == nil || (kind != CLIENT && kind != LEAF && kind != JETSTREAM) {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
var _subs [32]*subscription
|
||||
subs := _subs[:0]
|
||||
// Do not set c.subs to nil or delete the sub from c.subs here because
|
||||
// it will be needed in saveClosedClient (which has been started as a
|
||||
// go routine in markConnAsClosed). Cleanup will be done there.
|
||||
for _, sub := range c.subs {
|
||||
// Auto-unsubscribe subscriptions must be unsubscribed forcibly.
|
||||
sub.max = 0
|
||||
sub.close()
|
||||
subs = append(subs, sub)
|
||||
if !close {
|
||||
delete(c.subs, string(sub.sid))
|
||||
}
|
||||
}
|
||||
spoke := c.isSpokeLeafNode()
|
||||
c.mu.Unlock()
|
||||
|
||||
acc.sl.RemoveBatch(subs)
|
||||
|
||||
if srv != nil {
|
||||
qsubs := map[string]*qsub{}
|
||||
for _, sub := range subs {
|
||||
// Call unsubscribe here to cleanup shadow subscriptions and such.
|
||||
c.unsubscribe(acc, sub, true, false)
|
||||
// Update route as normal for a normal subscriber.
|
||||
if sub.queue == nil {
|
||||
if !spoke {
|
||||
srv.updateRouteSubscriptionMap(acc, sub, -1)
|
||||
if srv.gateway.enabled {
|
||||
srv.gatewayUpdateSubInterest(acc.Name, sub, -1)
|
||||
}
|
||||
}
|
||||
acc.updateLeafNodes(sub, -1)
|
||||
} else {
|
||||
// We handle queue subscribers special in case we
|
||||
// have a bunch we can just send one update to the
|
||||
// connected routes.
|
||||
num := int32(1)
|
||||
if kind == LEAF {
|
||||
num = sub.qw
|
||||
}
|
||||
key := keyFromSub(sub)
|
||||
if esub, ok := qsubs[key]; ok {
|
||||
esub.n += num
|
||||
} else {
|
||||
qsubs[key] = &qsub{sub, num}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Process any qsubs here.
|
||||
for _, esub := range qsubs {
|
||||
if !spoke {
|
||||
srv.updateRouteSubscriptionMap(acc, esub.sub, -(esub.n))
|
||||
if srv.gateway.enabled {
|
||||
srv.gatewayUpdateSubInterest(acc.Name, esub.sub, -(esub.n))
|
||||
}
|
||||
}
|
||||
acc.updateLeafNodes(esub.sub, -(esub.n))
|
||||
}
|
||||
}
|
||||
|
||||
if !close {
|
||||
// Register back to global account, mimicking the state after client initialization.
|
||||
c.registerWithAccount(srv.globalAccount())
|
||||
}
|
||||
}
|
||||
|
||||
// Depending on the kind of connections, this may attempt to recreate a connection.
|
||||
// The actual reconnect attempt will be started in a go routine.
|
||||
func (c *client) reconnect() {
|
||||
@@ -6180,7 +6249,7 @@ func (c *client) reconnect() {
|
||||
srv.Debugf("Gateway %q not in configuration, not attempting reconnect", gwName)
|
||||
}
|
||||
} else if leafCfg != nil {
|
||||
// Check if this is a solicited leaf node. Start up a reconnect.
|
||||
// This is a solicited leaf node. Start up a reconnect.
|
||||
srv.startGoRoutine(func() { srv.reConnectToRemoteLeafNode(leafCfg) })
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -66,7 +66,7 @@ func init() {
|
||||
|
||||
const (
|
||||
// VERSION is the current version for the server.
|
||||
VERSION = "2.12.6"
|
||||
VERSION = "2.14.0"
|
||||
|
||||
// PROTO is the currently supported protocol.
|
||||
// 0 was the original
|
||||
|
||||
+563
-106
File diff suppressed because it is too large
Load Diff
+327
@@ -0,0 +1,327 @@
|
||||
// Copyright 2025 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Based on code from https://github.com/robfig/cron
|
||||
// Copyright (C) 2012 Rob Figueiredo
|
||||
// All Rights Reserved.
|
||||
//
|
||||
// MIT LICENSE
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// parseCron parses the given cron pattern and returns the next time it will fire based on the provided ts.
|
||||
func parseCron(pattern string, loc *time.Location, ts int64) (time.Time, error) {
|
||||
fields := strings.Fields(pattern)
|
||||
if len(fields) != 6 {
|
||||
return time.Time{}, fmt.Errorf("pattern requires 6 fields, got %d", len(fields))
|
||||
}
|
||||
|
||||
// If no time zone is passed, default to UTC.
|
||||
if loc == nil {
|
||||
loc = time.UTC
|
||||
}
|
||||
|
||||
// Parse each field.
|
||||
var err error
|
||||
var second, minute, hour, dayOfMonth, month, dayOfWeek uint64
|
||||
if second, err = getField(fields[0], seconds); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if minute, err = getField(fields[1], minutes); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if hour, err = getField(fields[2], hours); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if dayOfMonth, err = getField(fields[3], dom); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if month, err = getField(fields[4], months); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if dayOfWeek, err = getField(fields[5], dow); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
// General approach
|
||||
//
|
||||
// For Month, Day, Hour, Minute, Second:
|
||||
// Check if the time value matches. If yes, continue to the next field.
|
||||
// If the field doesn't match the schedule, then increment the field until it matches.
|
||||
// While incrementing the field, a wrap-around brings it back to the beginning
|
||||
// of the field list (since it is necessary to re-verify previous field values)
|
||||
next := time.Unix(0, ts).In(loc)
|
||||
|
||||
// Start at the earliest possible time (the upcoming second).
|
||||
next = next.Truncate(time.Second).Add(time.Second)
|
||||
|
||||
// This flag indicates whether a field has been truncated at one point.
|
||||
truncated := false
|
||||
|
||||
// If no time is found within five years, return error.
|
||||
yearLimit := next.Year() + 5
|
||||
|
||||
WRAP:
|
||||
if next.Year() > yearLimit {
|
||||
return time.Time{}, errors.New("pattern exceeds maximum range")
|
||||
}
|
||||
for 1<<uint(next.Month())&month == 0 {
|
||||
if !truncated {
|
||||
truncated = true
|
||||
next = time.Date(next.Year(), next.Month(), 1, 0, 0, 0, 0, loc)
|
||||
}
|
||||
if next = next.AddDate(0, 1, 0); next.Month() == time.January {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
for !dayMatches(dayOfMonth, dayOfWeek, next) {
|
||||
if !truncated {
|
||||
truncated = true
|
||||
next = time.Date(next.Year(), next.Month(), next.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
if next = next.AddDate(0, 0, 1); next.Day() == 1 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
for 1<<uint(next.Hour())&hour == 0 {
|
||||
if !truncated {
|
||||
truncated = true
|
||||
next = time.Date(next.Year(), next.Month(), next.Day(), next.Hour(), 0, 0, 0, loc)
|
||||
}
|
||||
if next = next.Add(time.Hour); next.Hour() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
for 1<<uint(next.Minute())&minute == 0 {
|
||||
if !truncated {
|
||||
truncated = true
|
||||
next = next.Truncate(time.Minute)
|
||||
}
|
||||
if next = next.Add(time.Minute); next.Minute() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
for 1<<uint(next.Second())&second == 0 {
|
||||
if !truncated {
|
||||
truncated = true
|
||||
next = next.Truncate(time.Second)
|
||||
}
|
||||
if next = next.Add(time.Second); next.Second() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
return next, nil
|
||||
}
|
||||
|
||||
// getField returns an Int with the bits set representing all of the times that
|
||||
// the field represents or error parsing field value. A "field" is a comma-separated
|
||||
// list of "ranges".
|
||||
func getField(field string, r bounds) (uint64, error) {
|
||||
var bits uint64
|
||||
ranges := strings.FieldsFuncSeq(field, func(r rune) bool { return r == ',' })
|
||||
for expr := range ranges {
|
||||
bit, err := getRange(expr, r)
|
||||
if err != nil {
|
||||
return bits, err
|
||||
}
|
||||
bits |= bit
|
||||
}
|
||||
return bits, nil
|
||||
}
|
||||
|
||||
// getRange returns the bits indicated by the given expression: number | number [ "-" number ] [ "/" number ]
|
||||
// or error parsing range.
|
||||
func getRange(expr string, r bounds) (uint64, error) {
|
||||
var (
|
||||
start, end, step uint
|
||||
rangeAndStep = strings.Split(expr, "/")
|
||||
lowAndHigh = strings.Split(rangeAndStep[0], "-")
|
||||
singleDigit = len(lowAndHigh) == 1
|
||||
err error
|
||||
)
|
||||
|
||||
var extra uint64
|
||||
if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" {
|
||||
start = r.min
|
||||
end = r.max
|
||||
extra = starBit
|
||||
} else {
|
||||
start, err = parseIntOrName(lowAndHigh[0], r.names)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch len(lowAndHigh) {
|
||||
case 1:
|
||||
end = start
|
||||
case 2:
|
||||
end, err = parseIntOrName(lowAndHigh[1], r.names)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("too many hyphens: %s", expr)
|
||||
}
|
||||
}
|
||||
|
||||
switch len(rangeAndStep) {
|
||||
case 1:
|
||||
step = 1
|
||||
case 2:
|
||||
step, err = mustParseInt(rangeAndStep[1])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Special handling: "N/step" means "N-max/step".
|
||||
if singleDigit {
|
||||
end = r.max
|
||||
}
|
||||
if step > 1 {
|
||||
extra = 0
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("too many slashes: %s", expr)
|
||||
}
|
||||
|
||||
if start < r.min {
|
||||
return 0, fmt.Errorf("beginning of range (%d) below minimum (%d): %s", start, r.min, expr)
|
||||
}
|
||||
if end > r.max {
|
||||
return 0, fmt.Errorf("end of range (%d) above maximum (%d): %s", end, r.max, expr)
|
||||
}
|
||||
if start > end {
|
||||
return 0, fmt.Errorf("beginning of range (%d) beyond end of range (%d): %s", start, end, expr)
|
||||
}
|
||||
if step == 0 {
|
||||
return 0, fmt.Errorf("step of range should be a positive number: %s", expr)
|
||||
}
|
||||
return getBits(start, end, step) | extra, nil
|
||||
}
|
||||
|
||||
// parseIntOrName returns the (possibly-named) integer contained in expr.
|
||||
func parseIntOrName(expr string, names map[string]uint) (uint, error) {
|
||||
if names != nil {
|
||||
if namedInt, ok := names[strings.ToLower(expr)]; ok {
|
||||
return namedInt, nil
|
||||
}
|
||||
}
|
||||
return mustParseInt(expr)
|
||||
}
|
||||
|
||||
// mustParseInt parses the given expression as an int or returns an error.
|
||||
func mustParseInt(expr string) (uint, error) {
|
||||
num, err := strconv.Atoi(expr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse int from %s: %s", expr, err)
|
||||
}
|
||||
if num < 0 {
|
||||
return 0, fmt.Errorf("negative number (%d) not allowed: %s", num, expr)
|
||||
}
|
||||
return uint(num), nil
|
||||
}
|
||||
|
||||
// getBits sets all bits in the range [min, max], modulo the given step size.
|
||||
func getBits(min, max, step uint) uint64 {
|
||||
var bits uint64
|
||||
|
||||
// If step is 1, use shifts.
|
||||
if step == 1 {
|
||||
return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min)
|
||||
}
|
||||
|
||||
// Else, use a simple loop.
|
||||
for i := min; i <= max; i += step {
|
||||
bits |= 1 << i
|
||||
}
|
||||
return bits
|
||||
}
|
||||
|
||||
// bounds provides a range of acceptable values (plus a map of name to value).
|
||||
type bounds struct {
|
||||
min, max uint
|
||||
names map[string]uint
|
||||
}
|
||||
|
||||
// The bounds for each field.
|
||||
var (
|
||||
seconds = bounds{0, 59, nil}
|
||||
minutes = bounds{0, 59, nil}
|
||||
hours = bounds{0, 23, nil}
|
||||
dom = bounds{1, 31, nil}
|
||||
months = bounds{1, 12, map[string]uint{
|
||||
"jan": 1,
|
||||
"feb": 2,
|
||||
"mar": 3,
|
||||
"apr": 4,
|
||||
"may": 5,
|
||||
"jun": 6,
|
||||
"jul": 7,
|
||||
"aug": 8,
|
||||
"sep": 9,
|
||||
"oct": 10,
|
||||
"nov": 11,
|
||||
"dec": 12,
|
||||
}}
|
||||
dow = bounds{0, 6, map[string]uint{
|
||||
"sun": 0,
|
||||
"mon": 1,
|
||||
"tue": 2,
|
||||
"wed": 3,
|
||||
"thu": 4,
|
||||
"fri": 5,
|
||||
"sat": 6,
|
||||
}}
|
||||
)
|
||||
|
||||
const (
|
||||
// Set the top bit if a star was included in the expression.
|
||||
starBit = 1 << 63
|
||||
)
|
||||
|
||||
// dayMatches returns true if the schedule's day-of-week and day-of-month
|
||||
// restrictions are satisfied by the given time.
|
||||
func dayMatches(dayOfMonth, dayOfWeek uint64, t time.Time) bool {
|
||||
var (
|
||||
domMatch = 1<<uint(t.Day())&dayOfMonth > 0
|
||||
dowMatch = 1<<uint(t.Weekday())&dayOfWeek > 0
|
||||
)
|
||||
if dayOfMonth&starBit > 0 || dayOfWeek&starBit > 0 {
|
||||
return domMatch && dowMatch
|
||||
}
|
||||
return domMatch || dowMatch
|
||||
}
|
||||
+211
-1
@@ -2008,5 +2008,215 @@
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSMessageSchedulesSourceInvalidErr",
|
||||
"code": 400,
|
||||
"error_code": 10203,
|
||||
"description": "message schedules source is invalid",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSConsumerInvalidResetErr",
|
||||
"code": 400,
|
||||
"error_code": 10204,
|
||||
"description": "invalid reset: {err}",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSBatchPublishDisabledErr",
|
||||
"code": 400,
|
||||
"error_code": 10205,
|
||||
"description": "batch publish is disabled",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSBatchPublishInvalidPatternErr",
|
||||
"code": 400,
|
||||
"error_code": 10206,
|
||||
"description": "batch publish pattern is invalid",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSBatchPublishInvalidBatchIDErr",
|
||||
"code": 400,
|
||||
"error_code": 10207,
|
||||
"description": "batch publish ID is invalid",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSBatchPublishUnknownBatchIDErr",
|
||||
"code": 400,
|
||||
"error_code": 10208,
|
||||
"description": "batch publish ID unknown",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSMirrorWithBatchPublishErr",
|
||||
"code": 400,
|
||||
"error_code": 10209,
|
||||
"description": "stream mirrors can not also use batch publishing",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSAtomicPublishTooManyInflight",
|
||||
"code": 429,
|
||||
"error_code": 10210,
|
||||
"description": "atomic publish too many inflight",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSBatchPublishTooManyInflight",
|
||||
"code": 429,
|
||||
"error_code": 10211,
|
||||
"description": "batch publish too many inflight",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSMessageSchedulesSchedulerInvalidErr",
|
||||
"code": 400,
|
||||
"error_code": 10212,
|
||||
"description": "message schedules invalid scheduler",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSMirrorDurableConsumerCfgInvalid",
|
||||
"code": 400,
|
||||
"error_code": 10213,
|
||||
"description": "stream mirror consumer config is invalid",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSMirrorConsumerRequiresAckFCErr",
|
||||
"code": 400,
|
||||
"error_code": 10214,
|
||||
"description": "stream mirror consumer requires flow control ack policy",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSSourceDurableConsumerCfgInvalid",
|
||||
"code": 400,
|
||||
"error_code": 10215,
|
||||
"description": "stream source consumer config is invalid",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSSourceDurableConsumerDuplicateDetected",
|
||||
"code": 400,
|
||||
"error_code": 10216,
|
||||
"description": "duplicate stream source consumer detected",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSSourceConsumerRequiresAckFCErr",
|
||||
"code": 400,
|
||||
"error_code": 10217,
|
||||
"description": "stream source consumer requires flow control ack policy",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSConsumerAckFCRequiresPushErr",
|
||||
"code": 400,
|
||||
"error_code": 10218,
|
||||
"description": "flow control ack policy requires a push based consumer",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSConsumerAckFCRequiresFCErr",
|
||||
"code": 400,
|
||||
"error_code": 10219,
|
||||
"description": "flow control ack policy requires flow control",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSConsumerAckFCRequiresMaxAckPendingErr",
|
||||
"code": 400,
|
||||
"error_code": 10220,
|
||||
"description": "flow control ack policy requires max ack pending",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSConsumerAckFCRequiresNoAckWaitErr",
|
||||
"code": 400,
|
||||
"error_code": 10221,
|
||||
"description": "flow control ack policy requires unset ack wait",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSConsumerAckFCRequiresNoMaxDeliverErr",
|
||||
"code": 400,
|
||||
"error_code": 10222,
|
||||
"description": "flow control ack policy requires unset max deliver",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
},
|
||||
{
|
||||
"constant": "JSMessageSchedulesTimeZoneInvalidErr",
|
||||
"code": 400,
|
||||
"error_code": 10223,
|
||||
"description": "message schedules time zone is invalid",
|
||||
"comment": "",
|
||||
"help": "",
|
||||
"url": "",
|
||||
"deprecates": ""
|
||||
}
|
||||
]
|
||||
]
|
||||
+21
-11
@@ -247,14 +247,15 @@ type ServerCapability uint64
|
||||
|
||||
// ServerInfo identifies remote servers.
|
||||
type ServerInfo struct {
|
||||
Name string `json:"name"`
|
||||
Host string `json:"host"`
|
||||
ID string `json:"id"`
|
||||
Cluster string `json:"cluster,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Version string `json:"ver"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Host string `json:"host"`
|
||||
ID string `json:"id"`
|
||||
Cluster string `json:"cluster,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Version string `json:"ver"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
FeatureFlags map[string]bool `json:"feature_flags,omitempty"`
|
||||
// Whether JetStream is enabled (deprecated in favor of the `ServerCapability`).
|
||||
JetStream bool `json:"jetstream"`
|
||||
// Generic capability flags
|
||||
@@ -328,6 +329,7 @@ type ClientInfo struct {
|
||||
ClientType string `json:"client_type,omitempty"`
|
||||
MQTTClient string `json:"client_id,omitempty"` // This is the MQTT client ID
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
Reply string `json:"reply,omitempty"` // Original reply subject after a service import (only when needed).
|
||||
}
|
||||
|
||||
// forAssignmentSnap returns the minimum amount of ClientInfo we need for assignment snapshots.
|
||||
@@ -518,7 +520,7 @@ RESET:
|
||||
|
||||
// Grab tags and metadata.
|
||||
opts := s.getOpts()
|
||||
tags, metadata := opts.Tags, opts.Metadata
|
||||
tags, metadata, featureFlags := opts.Tags, opts.Metadata, opts.getMergedFeatureFlags()
|
||||
|
||||
for s.eventsRunning() {
|
||||
select {
|
||||
@@ -536,6 +538,7 @@ RESET:
|
||||
si.Time = time.Now().UTC()
|
||||
si.Tags = tags
|
||||
si.Metadata = metadata
|
||||
si.FeatureFlags = featureFlags
|
||||
si.Flags = 0
|
||||
if js {
|
||||
// New capability based flags.
|
||||
@@ -1052,8 +1055,15 @@ func (s *Server) sendStatsz(subj string) {
|
||||
Size: mg.ClusterSize(),
|
||||
}
|
||||
}
|
||||
if ipq := s.jsAPIRoutedReqs; ipq != nil && jStat.Meta != nil {
|
||||
jStat.Meta.Pending = ipq.len()
|
||||
if jStat.Meta != nil {
|
||||
if ipq := s.jsAPIRoutedReqs; ipq != nil {
|
||||
jStat.Meta.PendingRequests = ipq.len()
|
||||
}
|
||||
if ipq := s.jsAPIRoutedInfoReqs; ipq != nil {
|
||||
jStat.Meta.PendingInfos = ipq.len()
|
||||
}
|
||||
jStat.Meta.Pending = jStat.Meta.PendingRequests + jStat.Meta.PendingInfos
|
||||
jStat.Meta.Snapshot = s.metaClusterSnapshotStats(js, mg)
|
||||
}
|
||||
}
|
||||
jStat.Limits = &s.getOpts().JetStreamLimits
|
||||
|
||||
+130
@@ -0,0 +1,130 @@
|
||||
// Copyright 2026 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
FeatureFlagJsAckFormatV2 = "js_ack_fc_v2"
|
||||
FeatureFlagJsRaftDeleteRange = "js_raft_delete_range"
|
||||
)
|
||||
|
||||
var featureFlags = map[string]bool{
|
||||
// Use v2 format for `$JS.ACK.>` and `$JS.FC.>`.
|
||||
// - Introduced: 2.14.0, both v1 and v2 supported, only using v1.
|
||||
// - Enabled: TBD, both supported, v2 becomes the default.
|
||||
//
|
||||
// - v1: $JS.ACK.<stream name>.<consumer name>.<num delivered>.<stream sequence>.<consumer sequence>.<timestamp>.<num pending>
|
||||
// - v2: $JS.ACK.<domain>.<account hash>.<stream name>.<consumer name>.<num delivered>.<stream sequence>.<consumer sequence>.<timestamp>.<num pending>
|
||||
// See also: https://github.com/nats-io/nats-architecture-and-design/blob/main/adr/ADR-15.md#jsack
|
||||
FeatureFlagJsAckFormatV2: false,
|
||||
|
||||
// Propose delete range gaps as a single `deleteRangeOp` Raft append entry
|
||||
// instead of one entry per deleted sequence. Dramatically reduces Raft cost
|
||||
// on mirrors whose origin has a large number of interior deletes.
|
||||
// - Introduced: 2.14.0, apply-side always supports receiving `deleteRangeOp`.
|
||||
// - Enabled: TBD, once all supported versions carry the apply-side.
|
||||
//
|
||||
// WARNING: Only enable once every peer in the cluster is on a version that
|
||||
// supports receiving `deleteRangeOp`. Older peers panic on apply of an
|
||||
// unknown stream entry operation.
|
||||
FeatureFlagJsRaftDeleteRange: false,
|
||||
}
|
||||
|
||||
// getFeatureFlag is used to retrieve either the default or overwritten value for a feature flag.
|
||||
// The user's value takes precedence over the system's default. However, if the flag doesn't exist, it's disabled.
|
||||
// The *Options returned by Server.getOpts() is treated as immutable, mutations go through setOpts,
|
||||
// so no lock is required on the map read here.
|
||||
func (o *Options) getFeatureFlag(k string) bool {
|
||||
defaultValue, ok := featureFlags[k]
|
||||
if !ok {
|
||||
return false // Not supported.
|
||||
}
|
||||
if userValue, ok := o.FeatureFlags[k]; ok {
|
||||
return userValue
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getMergedFeatureFlags returns a merged map of feature flags, with the user's values taking precedence.
|
||||
func (o *Options) getMergedFeatureFlags() map[string]bool {
|
||||
merged := make(map[string]bool)
|
||||
for k, v := range featureFlags {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range o.FeatureFlags {
|
||||
if _, ok := featureFlags[k]; !ok {
|
||||
continue
|
||||
}
|
||||
merged[k] = v
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// printFeatureFlags logs the currently used feature flags on server startup.
|
||||
func (s *Server) printFeatureFlags(o *Options) {
|
||||
if len(o.FeatureFlags) == 0 {
|
||||
return
|
||||
}
|
||||
keys := slices.Sorted(maps.Keys(o.FeatureFlags))
|
||||
|
||||
var (
|
||||
configured strings.Builder
|
||||
unsupported strings.Builder
|
||||
)
|
||||
|
||||
for _, k := range keys {
|
||||
// Unsupported
|
||||
defaultValue, ok := featureFlags[k]
|
||||
if !ok {
|
||||
if unsupported.Len() > 0 {
|
||||
unsupported.WriteString(", ")
|
||||
}
|
||||
unsupported.WriteString(k)
|
||||
continue
|
||||
}
|
||||
|
||||
v := o.FeatureFlags[k]
|
||||
if configured.Len() > 0 {
|
||||
configured.WriteString(", ")
|
||||
}
|
||||
configured.WriteString(k)
|
||||
configured.WriteString(" (")
|
||||
if defaultValue {
|
||||
if v {
|
||||
configured.WriteString("enabled")
|
||||
} else {
|
||||
configured.WriteString("opt-out")
|
||||
}
|
||||
} else if v {
|
||||
configured.WriteString("opt-in")
|
||||
} else {
|
||||
configured.WriteString("disabled")
|
||||
}
|
||||
configured.WriteString(")")
|
||||
}
|
||||
if configured.Len() == 0 {
|
||||
configured.WriteString("none")
|
||||
}
|
||||
|
||||
s.Noticef(" Feature flags:")
|
||||
s.Noticef(" Configured: %s", configured.String())
|
||||
if unsupported.Len() > 0 {
|
||||
s.Noticef(" Unsupported: %s", unsupported.String())
|
||||
}
|
||||
}
|
||||
+1909
-651
File diff suppressed because it is too large
Load Diff
+14
-6
@@ -1156,9 +1156,7 @@ func (c *client) processGatewayInfo(info *Info) {
|
||||
// defensive code above that if we did not register this connection
|
||||
// because we already have an outbound for this name, then
|
||||
// close this connection (and make sure it does not try to reconnect)
|
||||
c.mu.Lock()
|
||||
c.flags.set(noReconnect)
|
||||
c.mu.Unlock()
|
||||
c.setNoReconnect()
|
||||
c.closeConnection(WrongGateway)
|
||||
return
|
||||
}
|
||||
@@ -1981,7 +1979,7 @@ func (c *client) processGatewayRUnsub(arg []byte) error {
|
||||
return nil
|
||||
} else {
|
||||
// Plain sub, assume optimistic sends, create entry.
|
||||
e = &outsie{ni: make(map[string]struct{}), sl: NewSublistWithCache()}
|
||||
e = &outsie{ni: make(map[string]struct{}), sl: NewSublistForServer(c.srv)}
|
||||
newe = true
|
||||
}
|
||||
// This is when a sub or queue sub is supposed to be in
|
||||
@@ -2090,7 +2088,7 @@ func (c *client) processGatewayRSub(arg []byte) error {
|
||||
} else if queue == nil {
|
||||
return nil
|
||||
} else {
|
||||
e = &outsie{ni: make(map[string]struct{}), sl: NewSublistWithCache()}
|
||||
e = &outsie{ni: make(map[string]struct{}), sl: NewSublistForServer(c.srv)}
|
||||
newe = true
|
||||
useSl = true
|
||||
}
|
||||
@@ -2952,6 +2950,16 @@ func getSubjectFromGWRoutedReply(reply []byte, isOldPrefix bool) []byte {
|
||||
return reply[gwSubjectOffset:]
|
||||
}
|
||||
|
||||
// Returns the subject embedded in the given routed
|
||||
// reply subject and whether the prefix was stripped.
|
||||
// If the subject is not routed, returns it unchanged.
|
||||
func getGWRoutedSubjectOrSelf(subject []byte) ([]byte, bool) {
|
||||
if isGWPrefix, oldPrefix := isGWRoutedSubjectAndIsOldPrefix(subject); isGWPrefix {
|
||||
return getSubjectFromGWRoutedReply(subject, oldPrefix), true
|
||||
}
|
||||
return subject, false
|
||||
}
|
||||
|
||||
// This should be invoked only from processInboundGatewayMsg() or
|
||||
// processInboundRoutedMsg() and is checking if the subject
|
||||
// (c.pa.subject) has the _GR_ prefix. If so, this is processed
|
||||
@@ -3201,7 +3209,7 @@ func (c *client) gatewayAllSubsReceiveStart(info *Info) {
|
||||
e.mode = Transitioning
|
||||
e.Unlock()
|
||||
} else {
|
||||
e := &outsie{sl: NewSublistWithCache()}
|
||||
e := &outsie{sl: NewSublistForServer(c.srv)}
|
||||
e.mode = Transitioning
|
||||
c.mu.Lock()
|
||||
c.gw.outsim.Store(account, e)
|
||||
|
||||
+60
@@ -170,6 +170,66 @@ func (s *GenericSublist[T]) NumInterest(subject string) (np int) {
|
||||
return
|
||||
}
|
||||
|
||||
// MatchesFullWildcard returns true if there is top-level ">" interest.
|
||||
func (s *GenericSublist[T]) MatchesFullWildcard() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return s.root.fwc != nil
|
||||
}
|
||||
|
||||
// MatchesSingleFilter returns the filter when the sublist contains exactly one unique subject.
|
||||
func (s *GenericSublist[T]) MatchesSingleFilter() (string, bool) {
|
||||
if s == nil {
|
||||
return _EMPTY_, false
|
||||
}
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
return singleFilter(s.root, _EMPTY_)
|
||||
}
|
||||
|
||||
func singleFilter[T comparable](l *level[T], filter string) (string, bool) {
|
||||
if l == nil {
|
||||
return filter, filter != _EMPTY_
|
||||
}
|
||||
if len(l.nodes) > 1 {
|
||||
return _EMPTY_, false
|
||||
}
|
||||
var next *node[T]
|
||||
branches := 0
|
||||
if l.pwc != nil {
|
||||
next = l.pwc
|
||||
branches++
|
||||
}
|
||||
if l.fwc != nil {
|
||||
next = l.fwc
|
||||
branches++
|
||||
}
|
||||
for _, n := range l.nodes {
|
||||
next = n
|
||||
branches++
|
||||
}
|
||||
if branches != 1 {
|
||||
return _EMPTY_, false
|
||||
}
|
||||
for _, subj := range next.subs {
|
||||
filter = subj
|
||||
break
|
||||
}
|
||||
if next.next == nil {
|
||||
return filter, filter != _EMPTY_
|
||||
}
|
||||
if filter != _EMPTY_ {
|
||||
if next.next.numNodes() > 0 {
|
||||
return _EMPTY_, false
|
||||
}
|
||||
return filter, true
|
||||
}
|
||||
return singleFilter(next.next, filter)
|
||||
}
|
||||
|
||||
func (s *GenericSublist[T]) match(subject string, cb func(T), doLock bool) {
|
||||
tsa := [32]string{}
|
||||
tokens := tsa[:0]
|
||||
|
||||
+63
-423
@@ -25,13 +25,13 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/minio/highwayhash"
|
||||
"github.com/nats-io/nats-server/v2/server/gsl"
|
||||
"github.com/nats-io/nats-server/v2/server/sysmem"
|
||||
"github.com/nats-io/nats-server/v2/server/tpm"
|
||||
"github.com/nats-io/nkeys"
|
||||
@@ -103,22 +103,26 @@ type JetStreamAPIStats struct {
|
||||
// This is for internal accounting for JetStream for this server.
|
||||
type jetStream struct {
|
||||
// These are here first because of atomics on 32bit systems.
|
||||
apiInflight int64
|
||||
apiTotal int64
|
||||
apiErrors int64
|
||||
memReserved int64
|
||||
storeReserved int64
|
||||
memUsed int64
|
||||
storeUsed int64
|
||||
queueLimit int64
|
||||
clustered int32
|
||||
mu sync.RWMutex
|
||||
srv *Server
|
||||
config JetStreamConfig
|
||||
cluster *jetStreamCluster
|
||||
accounts map[string]*jsAccount
|
||||
apiSubs *Sublist
|
||||
started time.Time
|
||||
apiInflight int64
|
||||
apiTotal int64
|
||||
apiErrors int64
|
||||
memMax int64
|
||||
memReserved int64 // Requires JS lock to be held.
|
||||
memUsed int64
|
||||
storeMax int64
|
||||
storeReserved int64 // Requires JS lock to be held.
|
||||
storeUsed int64
|
||||
queueLimit int64
|
||||
infoQueueLimit int64
|
||||
clustered int32
|
||||
mu sync.RWMutex
|
||||
srv *Server
|
||||
config JetStreamConfig
|
||||
cluster *jetStreamCluster
|
||||
accounts map[string]*jsAccount
|
||||
apiSubs *Sublist
|
||||
infoSubs *gsl.SimpleSublist // Subjects for info-specific queue.
|
||||
started time.Time
|
||||
|
||||
// System level request to purge a stream move
|
||||
accountPurge *subscription
|
||||
@@ -150,14 +154,12 @@ type jsaStorage struct {
|
||||
// an internal sub for a stream, so we will direct link to the stream
|
||||
// and walk backwards as needed vs multiple hash lookups and locks, etc.
|
||||
type jsAccount struct {
|
||||
mu sync.RWMutex
|
||||
js *jetStream
|
||||
account *Account
|
||||
storeDir string
|
||||
inflight sync.Map
|
||||
streams map[string]*stream
|
||||
templates map[string]*streamTemplate // Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
store TemplateStore // Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
mu sync.RWMutex
|
||||
js *jetStream
|
||||
account *Account
|
||||
storeDir string
|
||||
inflight sync.Map
|
||||
streams map[string]*stream
|
||||
|
||||
// From server
|
||||
sendq *ipQueue[*pubMsg]
|
||||
@@ -415,15 +417,19 @@ func (s *Server) initJetStreamEncryption() (err error) {
|
||||
|
||||
// enableJetStream will start up the JetStream subsystem.
|
||||
func (s *Server) enableJetStream(cfg JetStreamConfig) error {
|
||||
js := &jetStream{srv: s, config: cfg, accounts: make(map[string]*jsAccount), apiSubs: NewSublistNoCache()}
|
||||
js := &jetStream{srv: s, config: cfg, accounts: make(map[string]*jsAccount), apiSubs: NewSublistNoCache(), infoSubs: gsl.NewSimpleSublist()}
|
||||
s.gcbMu.Lock()
|
||||
if s.gcbOutMax = s.getOpts().JetStreamMaxCatchup; s.gcbOutMax == 0 {
|
||||
s.gcbOutMax = defaultMaxTotalCatchupOutBytes
|
||||
}
|
||||
s.gcbMu.Unlock()
|
||||
|
||||
atomic.StoreInt64(&js.memMax, cfg.MaxMemory)
|
||||
atomic.StoreInt64(&js.storeMax, cfg.MaxStore)
|
||||
|
||||
// TODO: Not currently reloadable.
|
||||
atomic.StoreInt64(&js.queueLimit, s.getOpts().JetStreamRequestQueueLimit)
|
||||
atomic.StoreInt64(&js.infoQueueLimit, s.getOpts().JetStreamInfoQueueLimit)
|
||||
|
||||
s.js.Store(js)
|
||||
|
||||
@@ -1058,8 +1064,10 @@ func (s *Server) shutdownJetStream() {
|
||||
func (s *Server) JetStreamConfig() *JetStreamConfig {
|
||||
var c *JetStreamConfig
|
||||
if js := s.getJetStream(); js != nil {
|
||||
js.mu.RLock()
|
||||
copy := js.config
|
||||
c = &(copy)
|
||||
js.mu.RUnlock()
|
||||
}
|
||||
return c
|
||||
}
|
||||
@@ -1219,54 +1227,6 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
|
||||
s.Debugf("Recovering JetStream state for account %q", a.Name)
|
||||
}
|
||||
|
||||
// Check templates first since messsage sets will need proper ownership.
|
||||
// FIXME(dlc) - Make this consistent.
|
||||
tdir := filepath.Join(jsa.storeDir, tmplsDir)
|
||||
if stat, err := os.Stat(tdir); err == nil && stat.IsDir() {
|
||||
key := sha256.Sum256([]byte("templates"))
|
||||
hh, err := highwayhash.NewDigest64(key[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fis, _ := os.ReadDir(tdir)
|
||||
for _, fi := range fis {
|
||||
metafile := filepath.Join(tdir, fi.Name(), JetStreamMetaFile)
|
||||
metasum := filepath.Join(tdir, fi.Name(), JetStreamMetaFileSum)
|
||||
buf, err := os.ReadFile(metafile)
|
||||
if err != nil {
|
||||
s.Warnf(" Error reading StreamTemplate metafile %q: %v", metasum, err)
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(metasum); os.IsNotExist(err) {
|
||||
s.Warnf(" Missing StreamTemplate checksum for %q", metasum)
|
||||
continue
|
||||
}
|
||||
sum, err := os.ReadFile(metasum)
|
||||
if err != nil {
|
||||
s.Warnf(" Error reading StreamTemplate checksum %q: %v", metasum, err)
|
||||
continue
|
||||
}
|
||||
hh.Reset()
|
||||
hh.Write(buf)
|
||||
var hb [highwayhash.Size64]byte
|
||||
checksum := hex.EncodeToString(hh.Sum(hb[:0]))
|
||||
if checksum != string(sum) {
|
||||
s.Warnf(" StreamTemplate checksums do not match %q vs %q", sum, checksum)
|
||||
continue
|
||||
}
|
||||
var cfg StreamTemplateConfig
|
||||
if err := json.Unmarshal(buf, &cfg); err != nil {
|
||||
s.Warnf(" Error unmarshalling StreamTemplate metafile: %v", err)
|
||||
continue
|
||||
}
|
||||
cfg.Config.Name = _EMPTY_
|
||||
if _, err := a.addStreamTemplate(&cfg); err != nil {
|
||||
s.Warnf(" Error recreating StreamTemplate %q: %v", cfg.Name, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remember if we should be encrypted and what cipher we think we should use.
|
||||
encrypted := s.getOpts().JetStreamKey != _EMPTY_
|
||||
sc := s.getOpts().JetStreamCipher
|
||||
@@ -1510,15 +1470,6 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.Template != _EMPTY_ {
|
||||
jsa.mu.Lock()
|
||||
err := jsa.addStreamNameToTemplate(cfg.Template, cfg.Name)
|
||||
jsa.mu.Unlock()
|
||||
if err != nil {
|
||||
s.Warnf(" Error adding stream %q to template %q: %v", cfg.Name, cfg.Template, err)
|
||||
}
|
||||
}
|
||||
|
||||
// We had a bug that set a default de dupe window on mirror, despite that being not a valid config
|
||||
fixCfgMirrorWithDedupWindow(&cfg.StreamConfig)
|
||||
|
||||
@@ -1587,6 +1538,7 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
|
||||
batchId string
|
||||
batchSeq uint64
|
||||
commit bool
|
||||
commitEob bool
|
||||
batchStoreDir string
|
||||
store StreamStore
|
||||
state StreamState
|
||||
@@ -1604,19 +1556,30 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
|
||||
}
|
||||
// We've observed a partial batch write. Write the remainder of the batch.
|
||||
batchSeq++
|
||||
_, batchStoreDir = getBatchStoreDir(mset, batchId)
|
||||
_, batchStoreDir = getBatchStoreDir(jsa.storeDir, cfg.Name, batchId)
|
||||
if _, err = os.Stat(batchStoreDir); err != nil {
|
||||
s.Errorf(" Failed restoring partial batch write for stream '%s > %s' at sequence %d: %v",
|
||||
mset.accName(), mset.name(), batchSeq, err)
|
||||
goto SKIP
|
||||
}
|
||||
store, err = newBatchStore(mset, batchId)
|
||||
store, err = newBatchStore(mset, batchId, cfg.Replicas, cfg.Storage, jsa.storeDir, cfg.Name)
|
||||
if err != nil {
|
||||
s.Errorf(" Failed restoring partial batch write for stream '%s > %s' at sequence %d: %v",
|
||||
mset.accName(), mset.name(), batchSeq, err)
|
||||
goto SKIP
|
||||
}
|
||||
store.FastState(&state)
|
||||
sm, err = store.LoadMsg(state.LastSeq, &smv)
|
||||
if err != nil || sm == nil {
|
||||
s.Errorf(" Failed restoring partial batch write for stream '%s > %s' at sequence %d: last msg not found %d",
|
||||
mset.accName(), mset.name(), batchSeq, state.LastSeq)
|
||||
goto SKIP
|
||||
}
|
||||
commitEob = bytes.Equal(sliceHeader(JSBatchCommit, sm.hdr), []byte("eob"))
|
||||
// If the commit ends with an "End Of Batch" message, we don't store this.
|
||||
if commitEob {
|
||||
state.LastSeq--
|
||||
}
|
||||
s.Noticef(" Restoring partial batch write for stream '%s > %s' (seq %d to %d)",
|
||||
mset.accName(), mset.name(), batchSeq, state.LastSeq)
|
||||
// Loop through items that weren't persisted yet.
|
||||
@@ -1627,7 +1590,12 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
|
||||
mset.accName(), mset.name(), seq, err)
|
||||
break
|
||||
}
|
||||
mset.processJetStreamMsg(sm.subj, _EMPTY_, sm.hdr, sm.msg, 0, 0, nil, false, true)
|
||||
hdr := sm.hdr
|
||||
// If committed by EOB, the last message must get the normal commit header.
|
||||
if commitEob && seq == state.LastSeq {
|
||||
hdr = genHeader(hdr, JSBatchCommit, "1")
|
||||
}
|
||||
mset.processJetStreamMsg(sm.subj, _EMPTY_, hdr, sm.msg, 0, 0, nil, false, true)
|
||||
}
|
||||
store.Delete(true)
|
||||
SKIP:
|
||||
@@ -2342,14 +2310,14 @@ func (jsa *jsAccount) sendClusterUsageUpdate() {
|
||||
func (js *jetStream) wouldExceedLimits(storeType StorageType, sz int) bool {
|
||||
var (
|
||||
total *int64
|
||||
max int64
|
||||
max *int64
|
||||
)
|
||||
if storeType == MemoryStorage {
|
||||
total, max = &js.memUsed, js.config.MaxMemory
|
||||
total, max = &js.memUsed, &js.memMax
|
||||
} else {
|
||||
total, max = &js.storeUsed, js.config.MaxStore
|
||||
total, max = &js.storeUsed, &js.storeMax
|
||||
}
|
||||
return (atomic.LoadInt64(total) + int64(sz)) > max
|
||||
return (atomic.LoadInt64(total) + int64(sz)) > atomic.LoadInt64(max)
|
||||
}
|
||||
|
||||
func (js *jetStream) limitsExceeded(storeType StorageType) bool {
|
||||
@@ -2519,7 +2487,6 @@ func (jsa *jsAccount) acc() *Account {
|
||||
// Delete the JetStream resources.
|
||||
func (jsa *jsAccount) delete() {
|
||||
var streams []*stream
|
||||
var ts []string
|
||||
|
||||
jsa.mu.Lock()
|
||||
// The update timer and subs need to be protected by usageMu lock
|
||||
@@ -2538,20 +2505,11 @@ func (jsa *jsAccount) delete() {
|
||||
for _, ms := range jsa.streams {
|
||||
streams = append(streams, ms)
|
||||
}
|
||||
acc := jsa.account
|
||||
for _, t := range jsa.templates {
|
||||
ts = append(ts, t.Name)
|
||||
}
|
||||
jsa.templates = nil
|
||||
jsa.mu.Unlock()
|
||||
|
||||
for _, mset := range streams {
|
||||
mset.stop(false, false)
|
||||
}
|
||||
|
||||
for _, t := range ts {
|
||||
acc.deleteStreamTemplate(t)
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup the jetstream account for a given account.
|
||||
@@ -2763,325 +2721,6 @@ func (a *Account) checkForJetStream() (*Server, *jsAccount, error) {
|
||||
return s, jsa, nil
|
||||
}
|
||||
|
||||
// StreamTemplateConfig allows a configuration to auto-create streams based on this template when a message
|
||||
// is received that matches. Each new stream will use the config as the template config to create them.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type StreamTemplateConfig struct {
|
||||
Name string `json:"name"`
|
||||
Config *StreamConfig `json:"config"`
|
||||
MaxStreams uint32 `json:"max_streams"`
|
||||
}
|
||||
|
||||
// StreamTemplateInfo
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type StreamTemplateInfo struct {
|
||||
Config *StreamTemplateConfig `json:"config"`
|
||||
Streams []string `json:"streams"`
|
||||
}
|
||||
|
||||
// streamTemplate
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type streamTemplate struct {
|
||||
mu sync.Mutex
|
||||
tc *client
|
||||
jsa *jsAccount
|
||||
*StreamTemplateConfig
|
||||
streams []string
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (t *StreamTemplateConfig) deepCopy() *StreamTemplateConfig {
|
||||
copy := *t
|
||||
cfg := *t.Config
|
||||
copy.Config = &cfg
|
||||
return ©
|
||||
}
|
||||
|
||||
// addStreamTemplate will add a stream template to this account that allows auto-creation of streams.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (a *Account) addStreamTemplate(tc *StreamTemplateConfig) (*streamTemplate, error) {
|
||||
s, jsa, err := a.checkForJetStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tc.Config.Name != "" {
|
||||
return nil, fmt.Errorf("template config name should be empty")
|
||||
}
|
||||
if len(tc.Name) > JSMaxNameLen {
|
||||
return nil, fmt.Errorf("template name is too long, maximum allowed is %d", JSMaxNameLen)
|
||||
}
|
||||
|
||||
// FIXME(dlc) - Hacky
|
||||
tcopy := tc.deepCopy()
|
||||
tcopy.Config.Name = "_"
|
||||
cfg, apiErr := s.checkStreamCfg(tcopy.Config, a, false)
|
||||
if apiErr != nil {
|
||||
return nil, apiErr
|
||||
}
|
||||
tcopy.Config = &cfg
|
||||
t := &streamTemplate{
|
||||
StreamTemplateConfig: tcopy,
|
||||
tc: s.createInternalJetStreamClient(),
|
||||
jsa: jsa,
|
||||
}
|
||||
t.tc.registerWithAccount(a)
|
||||
|
||||
jsa.mu.Lock()
|
||||
if jsa.templates == nil {
|
||||
jsa.templates = make(map[string]*streamTemplate)
|
||||
// Create the appropriate store
|
||||
if cfg.Storage == FileStorage {
|
||||
jsa.store = newTemplateFileStore(jsa.storeDir)
|
||||
} else {
|
||||
jsa.store = newTemplateMemStore()
|
||||
}
|
||||
} else if _, ok := jsa.templates[tcopy.Name]; ok {
|
||||
jsa.mu.Unlock()
|
||||
return nil, fmt.Errorf("template with name %q already exists", tcopy.Name)
|
||||
}
|
||||
jsa.templates[tcopy.Name] = t
|
||||
jsa.mu.Unlock()
|
||||
|
||||
// FIXME(dlc) - we can not overlap subjects between templates. Need to have test.
|
||||
|
||||
// Setup the internal subscriptions to trap the messages.
|
||||
if err := t.createTemplateSubscriptions(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := jsa.store.Store(t); err != nil {
|
||||
t.delete()
|
||||
return nil, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (t *streamTemplate) createTemplateSubscriptions() error {
|
||||
if t == nil {
|
||||
return fmt.Errorf("no template")
|
||||
}
|
||||
if t.tc == nil {
|
||||
return fmt.Errorf("template not enabled")
|
||||
}
|
||||
c := t.tc
|
||||
if !c.srv.EventsEnabled() {
|
||||
return ErrNoSysAccount
|
||||
}
|
||||
sid := 1
|
||||
for _, subject := range t.Config.Subjects {
|
||||
// Now create the subscription
|
||||
if _, err := c.processSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg, false); err != nil {
|
||||
c.acc.deleteStreamTemplate(t.Name)
|
||||
return err
|
||||
}
|
||||
sid++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (t *streamTemplate) processInboundTemplateMsg(_ *subscription, pc *client, acc *Account, subject, reply string, msg []byte) {
|
||||
if t == nil || t.jsa == nil {
|
||||
return
|
||||
}
|
||||
jsa := t.jsa
|
||||
cn := canonicalName(subject)
|
||||
|
||||
jsa.mu.Lock()
|
||||
// If we already are registered then we can just return here.
|
||||
if _, ok := jsa.streams[cn]; ok {
|
||||
jsa.mu.Unlock()
|
||||
return
|
||||
}
|
||||
jsa.mu.Unlock()
|
||||
|
||||
// Check if we are at the maximum and grab some variables.
|
||||
t.mu.Lock()
|
||||
c := t.tc
|
||||
cfg := *t.Config
|
||||
cfg.Template = t.Name
|
||||
atLimit := len(t.streams) >= int(t.MaxStreams)
|
||||
if !atLimit {
|
||||
t.streams = append(t.streams, cn)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
if atLimit {
|
||||
c.RateLimitWarnf("JetStream could not create stream for account %q on subject %q, at limit", acc.Name, subject)
|
||||
return
|
||||
}
|
||||
|
||||
// We need to create the stream here.
|
||||
// Change the config from the template and only use literal subject.
|
||||
cfg.Name = cn
|
||||
cfg.Subjects = []string{subject}
|
||||
mset, err := acc.addStream(&cfg)
|
||||
if err != nil {
|
||||
acc.validateStreams(t)
|
||||
c.RateLimitWarnf("JetStream could not create stream for account %q on subject %q: %v", acc.Name, subject, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Process this message directly by invoking mset.
|
||||
mset.processInboundJetStreamMsg(nil, pc, acc, subject, reply, msg)
|
||||
}
|
||||
|
||||
// lookupStreamTemplate looks up the names stream template.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (a *Account) lookupStreamTemplate(name string) (*streamTemplate, error) {
|
||||
_, jsa, err := a.checkForJetStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsa.mu.Lock()
|
||||
defer jsa.mu.Unlock()
|
||||
if jsa.templates == nil {
|
||||
return nil, fmt.Errorf("template not found")
|
||||
}
|
||||
t, ok := jsa.templates[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("template not found")
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// This function will check all named streams and make sure they are valid.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (a *Account) validateStreams(t *streamTemplate) {
|
||||
t.mu.Lock()
|
||||
var vstreams []string
|
||||
for _, sname := range t.streams {
|
||||
if _, err := a.lookupStream(sname); err == nil {
|
||||
vstreams = append(vstreams, sname)
|
||||
}
|
||||
}
|
||||
t.streams = vstreams
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (t *streamTemplate) delete() error {
|
||||
if t == nil {
|
||||
return fmt.Errorf("nil stream template")
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
jsa := t.jsa
|
||||
c := t.tc
|
||||
t.tc = nil
|
||||
defer func() {
|
||||
if c != nil {
|
||||
c.closeConnection(ClientClosed)
|
||||
}
|
||||
}()
|
||||
t.mu.Unlock()
|
||||
|
||||
if jsa == nil {
|
||||
return NewJSNotEnabledForAccountError()
|
||||
}
|
||||
|
||||
jsa.mu.Lock()
|
||||
if jsa.templates == nil {
|
||||
jsa.mu.Unlock()
|
||||
return fmt.Errorf("template not found")
|
||||
}
|
||||
if _, ok := jsa.templates[t.Name]; !ok {
|
||||
jsa.mu.Unlock()
|
||||
return fmt.Errorf("template not found")
|
||||
}
|
||||
delete(jsa.templates, t.Name)
|
||||
acc := jsa.account
|
||||
jsa.mu.Unlock()
|
||||
|
||||
// Remove streams associated with this template.
|
||||
var streams []*stream
|
||||
t.mu.Lock()
|
||||
for _, name := range t.streams {
|
||||
if mset, err := acc.lookupStream(name); err == nil {
|
||||
streams = append(streams, mset)
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
if jsa.store != nil {
|
||||
if err := jsa.store.Delete(t); err != nil {
|
||||
return fmt.Errorf("error deleting template from store: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, mset := range streams {
|
||||
if err := mset.delete(); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (a *Account) deleteStreamTemplate(name string) error {
|
||||
t, err := a.lookupStreamTemplate(name)
|
||||
if err != nil {
|
||||
return NewJSStreamTemplateNotFoundError()
|
||||
}
|
||||
return t.delete()
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (a *Account) templates() []*streamTemplate {
|
||||
var ts []*streamTemplate
|
||||
_, jsa, err := a.checkForJetStream()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
jsa.mu.Lock()
|
||||
for _, t := range jsa.templates {
|
||||
// FIXME(dlc) - Copy?
|
||||
ts = append(ts, t)
|
||||
}
|
||||
jsa.mu.Unlock()
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// Will add a stream to a template, this is for recovery.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (jsa *jsAccount) addStreamNameToTemplate(tname, mname string) error {
|
||||
if jsa.templates == nil {
|
||||
return fmt.Errorf("template not found")
|
||||
}
|
||||
t, ok := jsa.templates[tname]
|
||||
if !ok {
|
||||
return fmt.Errorf("template not found")
|
||||
}
|
||||
// We found template.
|
||||
t.mu.Lock()
|
||||
t.streams = append(t.streams, mname)
|
||||
t.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// This will check if a template owns this stream.
|
||||
// jsAccount lock should be held
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (jsa *jsAccount) checkTemplateOwnership(tname, sname string) bool {
|
||||
if jsa.templates == nil {
|
||||
return false
|
||||
}
|
||||
t, ok := jsa.templates[tname]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// We found template, make sure we are in streams.
|
||||
for _, streamName := range t.streams {
|
||||
if sname == streamName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type Number interface {
|
||||
int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64
|
||||
}
|
||||
@@ -3107,10 +2746,11 @@ func isValidName(name string) bool {
|
||||
return !strings.ContainsAny(name, " \t\r\n\f.*>")
|
||||
}
|
||||
|
||||
// CanonicalName will replace all token separators '.' with '_'.
|
||||
// This can be used when naming streams or consumers with multi-token subjects.
|
||||
func canonicalName(name string) string {
|
||||
return strings.ReplaceAll(name, ".", "_")
|
||||
func isValidAssetName(name string) bool {
|
||||
if name == _EMPTY_ {
|
||||
return false
|
||||
}
|
||||
return !strings.ContainsAny(name, " \t\r\n\f.*>\\/")
|
||||
}
|
||||
|
||||
// To throttle the out of resources errors.
|
||||
|
||||
+175
-331
@@ -20,6 +20,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -46,29 +47,6 @@ const (
|
||||
// Will return JSON response.
|
||||
JSApiAccountInfo = "$JS.API.INFO"
|
||||
|
||||
// JSApiTemplateCreate is the endpoint to create new stream templates.
|
||||
// Will return JSON response.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
JSApiTemplateCreate = "$JS.API.STREAM.TEMPLATE.CREATE.*"
|
||||
JSApiTemplateCreateT = "$JS.API.STREAM.TEMPLATE.CREATE.%s"
|
||||
|
||||
// JSApiTemplates is the endpoint to list all stream template names for this account.
|
||||
// Will return JSON response.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
JSApiTemplates = "$JS.API.STREAM.TEMPLATE.NAMES"
|
||||
|
||||
// JSApiTemplateInfo is for obtaining general information about a named stream template.
|
||||
// Will return JSON response.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
JSApiTemplateInfo = "$JS.API.STREAM.TEMPLATE.INFO.*"
|
||||
JSApiTemplateInfoT = "$JS.API.STREAM.TEMPLATE.INFO.%s"
|
||||
|
||||
// JSApiTemplateDelete is the endpoint to delete stream templates.
|
||||
// Will return JSON response.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
JSApiTemplateDelete = "$JS.API.STREAM.TEMPLATE.DELETE.*"
|
||||
JSApiTemplateDeleteT = "$JS.API.STREAM.TEMPLATE.DELETE.%s"
|
||||
|
||||
// JSApiStreamCreate is the endpoint to create new streams.
|
||||
// Will return JSON response.
|
||||
JSApiStreamCreate = "$JS.API.STREAM.CREATE.*"
|
||||
@@ -177,6 +155,9 @@ const (
|
||||
// JSApiRequestNextT is the prefix for the request next message(s) for a consumer in worker/pull mode.
|
||||
JSApiRequestNextT = "$JS.API.CONSUMER.MSG.NEXT.%s.%s"
|
||||
|
||||
// JSApiConsumerResetT is the prefix for resetting a given consumer to a new starting sequence.
|
||||
JSApiConsumerResetT = "$JS.API.CONSUMER.RESET.%s.%s"
|
||||
|
||||
// JSApiConsumerUnpinT is the prefix for unpinning subscription for a given consumer.
|
||||
JSApiConsumerUnpin = "$JS.API.CONSUMER.UNPIN.*.*"
|
||||
JSApiConsumerUnpinT = "$JS.API.CONSUMER.UNPIN.%s.%s"
|
||||
@@ -237,13 +218,15 @@ const (
|
||||
// jsAckT is the template for the ack message stream coming back from a consumer
|
||||
// when they ACK/NAK, etc a message.
|
||||
jsAckT = "$JS.ACK.%s.%s"
|
||||
jsAckTv2 = "$JS.ACK.%s.%s.%s.%s"
|
||||
jsAckPre = "$JS.ACK."
|
||||
jsAckPreLen = len(jsAckPre)
|
||||
|
||||
// jsFlowControl is for flow control subjects.
|
||||
jsFlowControlPre = "$JS.FC."
|
||||
// jsFlowControl is for FC responses.
|
||||
jsFlowControl = "$JS.FC.%s.%s.*"
|
||||
jsFlowControl = "$JS.FC.%s.%s.*"
|
||||
jsFlowControlV2 = "$JS.FC.%s.%s.%s.%s.*"
|
||||
|
||||
// JSAdvisoryPrefix is a prefix for all JetStream advisories.
|
||||
JSAdvisoryPrefix = "$JS.EVENT.ADVISORY"
|
||||
@@ -787,50 +770,19 @@ type JSApiConsumerGetNextRequest struct {
|
||||
PriorityGroup
|
||||
}
|
||||
|
||||
// JSApiStreamTemplateCreateResponse for creating templates.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type JSApiStreamTemplateCreateResponse struct {
|
||||
// JSApiConsumerResetRequest is for resetting a consumer to a specific sequence.
|
||||
type JSApiConsumerResetRequest struct {
|
||||
Seq uint64 `json:"seq,omitempty"`
|
||||
}
|
||||
|
||||
// JSApiConsumerResetResponse is a superset of JSApiConsumerCreateResponse, but including an explicit ResetSeq.
|
||||
type JSApiConsumerResetResponse struct {
|
||||
ApiResponse
|
||||
*StreamTemplateInfo
|
||||
*ConsumerInfo
|
||||
ResetSeq uint64 `json:"reset_seq"`
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
const JSApiStreamTemplateCreateResponseType = "io.nats.jetstream.api.v1.stream_template_create_response"
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type JSApiStreamTemplateDeleteResponse struct {
|
||||
ApiResponse
|
||||
Success bool `json:"success,omitempty"`
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
const JSApiStreamTemplateDeleteResponseType = "io.nats.jetstream.api.v1.stream_template_delete_response"
|
||||
|
||||
// JSApiStreamTemplateInfoResponse for information about stream templates.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type JSApiStreamTemplateInfoResponse struct {
|
||||
ApiResponse
|
||||
*StreamTemplateInfo
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
const JSApiStreamTemplateInfoResponseType = "io.nats.jetstream.api.v1.stream_template_info_response"
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type JSApiStreamTemplatesRequest struct {
|
||||
ApiPagedRequest
|
||||
}
|
||||
|
||||
// JSApiStreamTemplateNamesResponse list of templates
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type JSApiStreamTemplateNamesResponse struct {
|
||||
ApiResponse
|
||||
ApiPaged
|
||||
Templates []string `json:"streams"`
|
||||
}
|
||||
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
const JSApiStreamTemplateNamesResponseType = "io.nats.jetstream.api.v1.stream_template_names_response"
|
||||
const JSApiConsumerResetResponseType = "io.nats.jetstream.api.v1.consumer_reset_response"
|
||||
|
||||
// Structure that holds state for a JetStream API request that is processed
|
||||
// in a separate long-lived go routine. This is to avoid blocking connections.
|
||||
@@ -911,11 +863,19 @@ func (js *jetStream) apiDispatch(sub *subscription, c *client, acc *Account, sub
|
||||
// Copy the state. Note the JSAPI only uses the hdr index to piece apart the
|
||||
// header from the msg body. No other references are needed.
|
||||
// Check pending and warn if getting backed up.
|
||||
pending, _ := s.jsAPIRoutedReqs.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa})
|
||||
limit := atomic.LoadInt64(&js.queueLimit)
|
||||
var queue *ipQueue[*jsAPIRoutedReq]
|
||||
var limit int64
|
||||
if js.infoSubs.HasInterest(subject) {
|
||||
queue = s.jsAPIRoutedInfoReqs
|
||||
limit = atomic.LoadInt64(&js.infoQueueLimit)
|
||||
} else {
|
||||
queue = s.jsAPIRoutedReqs
|
||||
limit = atomic.LoadInt64(&js.queueLimit)
|
||||
}
|
||||
pending, _ := queue.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa})
|
||||
if pending >= int(limit) {
|
||||
s.rateLimitFormatWarnf("JetStream API queue limit reached, dropping %d requests", pending)
|
||||
drained := int64(s.jsAPIRoutedReqs.drain())
|
||||
s.rateLimitFormatWarnf("%s limit reached, dropping %d requests", queue.name, pending)
|
||||
drained := int64(queue.drain())
|
||||
atomic.AddInt64(&js.apiInflight, -drained)
|
||||
|
||||
s.publishAdvisory(nil, JSAdvisoryAPILimitReached, JSAPILimitReachedAdvisory{
|
||||
@@ -935,29 +895,45 @@ func (s *Server) processJSAPIRoutedRequests() {
|
||||
defer s.grWG.Done()
|
||||
|
||||
s.mu.RLock()
|
||||
queue := s.jsAPIRoutedReqs
|
||||
queue, infoqueue := s.jsAPIRoutedReqs, s.jsAPIRoutedInfoReqs
|
||||
client := &client{srv: s, kind: JETSTREAM}
|
||||
s.mu.RUnlock()
|
||||
|
||||
js := s.getJetStream()
|
||||
|
||||
processFromQueue := func(ipq *ipQueue[*jsAPIRoutedReq]) {
|
||||
// Only pop one item at a time here, otherwise if the system is recovering
|
||||
// from queue buildup, then one worker will pull off all the tasks and the
|
||||
// others will be starved of work.
|
||||
if r, ok := ipq.popOne(); ok && r != nil {
|
||||
client.pa = r.pa
|
||||
start := time.Now()
|
||||
r.jsub.icb(r.sub, client, r.acc, r.subject, r.reply, r.msg)
|
||||
if dur := time.Since(start); dur >= readLoopReportThreshold {
|
||||
s.Warnf("Internal subscription on %q took too long: %v", r.subject, dur)
|
||||
}
|
||||
atomic.AddInt64(&js.apiInflight, -1)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
// First select case is prioritizing queue, we will only fall through
|
||||
// to the second select case that considers infoqueue if queue is empty.
|
||||
// This effectively means infos are deprioritized.
|
||||
select {
|
||||
case <-queue.ch:
|
||||
// Only pop one item at a time here, otherwise if the system is recovering
|
||||
// from queue buildup, then one worker will pull off all the tasks and the
|
||||
// others will be starved of work.
|
||||
for r, ok := queue.popOne(); ok && r != nil; r, ok = queue.popOne() {
|
||||
client.pa = r.pa
|
||||
start := time.Now()
|
||||
r.jsub.icb(r.sub, client, r.acc, r.subject, r.reply, r.msg)
|
||||
if dur := time.Since(start); dur >= readLoopReportThreshold {
|
||||
s.Warnf("Internal subscription on %q took too long: %v", r.subject, dur)
|
||||
}
|
||||
atomic.AddInt64(&js.apiInflight, -1)
|
||||
}
|
||||
processFromQueue(queue)
|
||||
case <-s.quitCh:
|
||||
return
|
||||
default:
|
||||
select {
|
||||
case <-infoqueue.ch:
|
||||
processFromQueue(infoqueue)
|
||||
case <-queue.ch:
|
||||
processFromQueue(queue)
|
||||
case <-s.quitCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -976,7 +952,8 @@ func (s *Server) setJetStreamExportSubs() error {
|
||||
if mp > maxProcs {
|
||||
mp = maxProcs
|
||||
}
|
||||
s.jsAPIRoutedReqs = newIPQueue[*jsAPIRoutedReq](s, "Routed JS API Requests")
|
||||
s.jsAPIRoutedReqs = newIPQueue[*jsAPIRoutedReq](s, "JetStream API queue")
|
||||
s.jsAPIRoutedInfoReqs = newIPQueue[*jsAPIRoutedReq](s, "JetStream API info queue")
|
||||
for i := 0; i < mp; i++ {
|
||||
s.startGoRoutine(s.processJSAPIRoutedRequests)
|
||||
}
|
||||
@@ -992,20 +969,13 @@ func (s *Server) setJetStreamExportSubs() error {
|
||||
}
|
||||
|
||||
// API handles themselves.
|
||||
// infopairs are deprioritized compared to pairs in processJSAPIRoutedRequests.
|
||||
pairs := []struct {
|
||||
subject string
|
||||
handler msgHandler
|
||||
}{
|
||||
{JSApiAccountInfo, s.jsAccountInfoRequest},
|
||||
{JSApiTemplateCreate, s.jsTemplateCreateRequest},
|
||||
{JSApiTemplates, s.jsTemplateNamesRequest},
|
||||
{JSApiTemplateInfo, s.jsTemplateInfoRequest},
|
||||
{JSApiTemplateDelete, s.jsTemplateDeleteRequest},
|
||||
{JSApiStreamCreate, s.jsStreamCreateRequest},
|
||||
{JSApiStreamUpdate, s.jsStreamUpdateRequest},
|
||||
{JSApiStreams, s.jsStreamNamesRequest},
|
||||
{JSApiStreamList, s.jsStreamListRequest},
|
||||
{JSApiStreamInfo, s.jsStreamInfoRequest},
|
||||
{JSApiStreamDelete, s.jsStreamDeleteRequest},
|
||||
{JSApiStreamPurge, s.jsStreamPurgeRequest},
|
||||
{JSApiStreamSnapshot, s.jsStreamSnapshotRequest},
|
||||
@@ -1018,23 +988,40 @@ func (s *Server) setJetStreamExportSubs() error {
|
||||
{JSApiConsumerCreateEx, s.jsConsumerCreateRequest},
|
||||
{JSApiConsumerCreate, s.jsConsumerCreateRequest},
|
||||
{JSApiDurableCreate, s.jsConsumerCreateRequest},
|
||||
{JSApiConsumers, s.jsConsumerNamesRequest},
|
||||
{JSApiConsumerList, s.jsConsumerListRequest},
|
||||
{JSApiConsumerInfo, s.jsConsumerInfoRequest},
|
||||
{JSApiConsumerDelete, s.jsConsumerDeleteRequest},
|
||||
{JSApiConsumerPause, s.jsConsumerPauseRequest},
|
||||
{JSApiConsumerUnpin, s.jsConsumerUnpinRequest},
|
||||
}
|
||||
infopairs := []struct {
|
||||
subject string
|
||||
handler msgHandler
|
||||
}{
|
||||
{JSApiAccountInfo, s.jsAccountInfoRequest},
|
||||
{JSApiStreams, s.jsStreamNamesRequest},
|
||||
{JSApiStreamList, s.jsStreamListRequest},
|
||||
{JSApiStreamInfo, s.jsStreamInfoRequest},
|
||||
{JSApiConsumers, s.jsConsumerNamesRequest},
|
||||
{JSApiConsumerList, s.jsConsumerListRequest},
|
||||
{JSApiConsumerInfo, s.jsConsumerInfoRequest},
|
||||
}
|
||||
|
||||
js.mu.Lock()
|
||||
defer js.mu.Unlock()
|
||||
|
||||
for _, p := range pairs {
|
||||
// As well as populating js.apiSubs for the dispatch function to use, we
|
||||
// will also populate js.infoSubs, so that the dispatch function can
|
||||
// decide quickly whether or not the request is an info request or not.
|
||||
for _, p := range append(infopairs, pairs...) {
|
||||
sub := &subscription{subject: []byte(p.subject), icb: p.handler}
|
||||
if err := js.apiSubs.Insert(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, p := range infopairs {
|
||||
if err := js.infoSubs.Insert(p.subject, struct{}{}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1239,7 +1226,7 @@ func (s *Server) unmarshalRequest(c *client, acc *Account, subject string, msg [
|
||||
|
||||
c.RateLimitWarnf("Invalid JetStream request '%s > %s': %s", acc, subject, err)
|
||||
|
||||
if s.JetStreamConfig().Strict {
|
||||
if js := s.getJetStream(); js != nil && js.config.Strict {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1345,10 +1332,6 @@ func (s *Server) jsAccountInfoRequest(sub *subscription, c *client, _ *Account,
|
||||
}
|
||||
|
||||
// Helpers for token extraction.
|
||||
func templateNameFromSubject(subject string) string {
|
||||
return tokenAt(subject, 6)
|
||||
}
|
||||
|
||||
func streamNameFromSubject(subject string) string {
|
||||
return tokenAt(subject, 5)
|
||||
}
|
||||
@@ -1357,223 +1340,6 @@ func consumerNameFromSubject(subject string) string {
|
||||
return tokenAt(subject, 6)
|
||||
}
|
||||
|
||||
// Request to create a new template.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (s *Server) jsTemplateCreateRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
ci, acc, hdr, msg, err := s.getRequestInfo(c, rmsg)
|
||||
if err != nil {
|
||||
s.Warnf(badAPIRequestT, msg)
|
||||
return
|
||||
}
|
||||
|
||||
var resp = JSApiStreamTemplateCreateResponse{ApiResponse: ApiResponse{Type: JSApiStreamTemplateCreateResponseType}}
|
||||
if errorOnRequiredApiLevel(hdr) {
|
||||
resp.Error = NewJSRequiredApiLevelError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
if !acc.JetStreamEnabled() {
|
||||
resp.Error = NewJSNotEnabledForAccountError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
|
||||
// Not supported for now.
|
||||
if s.JetStreamIsClustered() {
|
||||
resp.Error = NewJSClusterUnSupportFeatureError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
|
||||
var cfg StreamTemplateConfig
|
||||
if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil {
|
||||
resp.Error = NewJSInvalidJSONError(err)
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
templateName := templateNameFromSubject(subject)
|
||||
if templateName != cfg.Name {
|
||||
resp.Error = NewJSTemplateNameNotMatchSubjectError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
|
||||
t, err := acc.addStreamTemplate(&cfg)
|
||||
if err != nil {
|
||||
resp.Error = NewJSStreamTemplateCreateError(err, Unless(err))
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
t.mu.Lock()
|
||||
tcfg := t.StreamTemplateConfig.deepCopy()
|
||||
streams := t.streams
|
||||
if streams == nil {
|
||||
streams = []string{}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
resp.StreamTemplateInfo = &StreamTemplateInfo{Config: tcfg, Streams: streams}
|
||||
s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp))
|
||||
}
|
||||
|
||||
// Request for the list of all template names.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (s *Server) jsTemplateNamesRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
ci, acc, hdr, msg, err := s.getRequestInfo(c, rmsg)
|
||||
if err != nil {
|
||||
s.Warnf(badAPIRequestT, msg)
|
||||
return
|
||||
}
|
||||
|
||||
var resp = JSApiStreamTemplateNamesResponse{ApiResponse: ApiResponse{Type: JSApiStreamTemplateNamesResponseType}}
|
||||
if errorOnRequiredApiLevel(hdr) {
|
||||
resp.Error = NewJSRequiredApiLevelError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
if !acc.JetStreamEnabled() {
|
||||
resp.Error = NewJSNotEnabledForAccountError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
|
||||
// Not supported for now.
|
||||
if s.JetStreamIsClustered() {
|
||||
resp.Error = NewJSClusterUnSupportFeatureError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
|
||||
var offset int
|
||||
if isJSONObjectOrArray(msg) {
|
||||
var req JSApiStreamTemplatesRequest
|
||||
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
|
||||
resp.Error = NewJSInvalidJSONError(err)
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
offset = req.Offset
|
||||
}
|
||||
|
||||
ts := acc.templates()
|
||||
slices.SortFunc(ts, func(i, j *streamTemplate) int {
|
||||
return cmp.Compare(i.StreamTemplateConfig.Name, j.StreamTemplateConfig.Name)
|
||||
})
|
||||
|
||||
tcnt := len(ts)
|
||||
if offset > tcnt {
|
||||
offset = tcnt
|
||||
}
|
||||
|
||||
for _, t := range ts[offset:] {
|
||||
t.mu.Lock()
|
||||
name := t.Name
|
||||
t.mu.Unlock()
|
||||
resp.Templates = append(resp.Templates, name)
|
||||
if len(resp.Templates) >= JSApiNamesLimit {
|
||||
break
|
||||
}
|
||||
}
|
||||
resp.Total = tcnt
|
||||
resp.Limit = JSApiNamesLimit
|
||||
resp.Offset = offset
|
||||
if resp.Templates == nil {
|
||||
resp.Templates = []string{}
|
||||
}
|
||||
s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp))
|
||||
}
|
||||
|
||||
// Request for information about a stream template.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (s *Server) jsTemplateInfoRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
ci, acc, hdr, msg, err := s.getRequestInfo(c, rmsg)
|
||||
if err != nil {
|
||||
s.Warnf(badAPIRequestT, msg)
|
||||
return
|
||||
}
|
||||
|
||||
var resp = JSApiStreamTemplateInfoResponse{ApiResponse: ApiResponse{Type: JSApiStreamTemplateInfoResponseType}}
|
||||
if errorOnRequiredApiLevel(hdr) {
|
||||
resp.Error = NewJSRequiredApiLevelError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
if !acc.JetStreamEnabled() {
|
||||
resp.Error = NewJSNotEnabledForAccountError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
if !isEmptyRequest(msg) {
|
||||
resp.Error = NewJSNotEmptyRequestError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
name := templateNameFromSubject(subject)
|
||||
t, err := acc.lookupStreamTemplate(name)
|
||||
if err != nil {
|
||||
resp.Error = NewJSStreamTemplateNotFoundError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
t.mu.Lock()
|
||||
cfg := t.StreamTemplateConfig.deepCopy()
|
||||
streams := t.streams
|
||||
if streams == nil {
|
||||
streams = []string{}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
resp.StreamTemplateInfo = &StreamTemplateInfo{Config: cfg, Streams: streams}
|
||||
s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp))
|
||||
}
|
||||
|
||||
// Request to delete a stream template.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
func (s *Server) jsTemplateDeleteRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
ci, acc, hdr, msg, err := s.getRequestInfo(c, rmsg)
|
||||
if err != nil {
|
||||
s.Warnf(badAPIRequestT, msg)
|
||||
return
|
||||
}
|
||||
|
||||
var resp = JSApiStreamTemplateDeleteResponse{ApiResponse: ApiResponse{Type: JSApiStreamTemplateDeleteResponseType}}
|
||||
if errorOnRequiredApiLevel(hdr) {
|
||||
resp.Error = NewJSRequiredApiLevelError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
if !acc.JetStreamEnabled() {
|
||||
resp.Error = NewJSNotEnabledForAccountError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
if !isEmptyRequest(msg) {
|
||||
resp.Error = NewJSNotEmptyRequestError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
name := templateNameFromSubject(subject)
|
||||
err = acc.deleteStreamTemplate(name)
|
||||
if err != nil {
|
||||
resp.Error = NewJSStreamTemplateDeleteError(err, Unless(err))
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
resp.Success = true
|
||||
s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp))
|
||||
}
|
||||
|
||||
func (s *Server) jsonResponse(v any) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
@@ -2109,7 +1875,7 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s
|
||||
if cc != nil {
|
||||
// Check to make sure the stream is assigned.
|
||||
js.mu.RLock()
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, streamName)
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, streamName)
|
||||
var offline bool
|
||||
if sa != nil {
|
||||
clusterWideConsCount = len(sa.consumers)
|
||||
@@ -2338,7 +2104,7 @@ func (s *Server) jsStreamLeaderStepDownRequest(sub *subscription, c *client, _ *
|
||||
}
|
||||
|
||||
js.mu.RLock()
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, name)
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, name)
|
||||
js.mu.RUnlock()
|
||||
|
||||
if isLeader && sa == nil {
|
||||
@@ -2455,7 +2221,7 @@ func (s *Server) jsConsumerLeaderStepDownRequest(sub *subscription, c *client, _
|
||||
consumer := tokenAt(subject, 7)
|
||||
|
||||
js.mu.RLock()
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, stream)
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, stream)
|
||||
js.mu.RUnlock()
|
||||
|
||||
if isLeader && sa == nil {
|
||||
@@ -3456,7 +3222,7 @@ func (s *Server) jsMsgDeleteRequest(sub *subscription, c *client, _ *Account, su
|
||||
}
|
||||
|
||||
js.mu.RLock()
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, stream)
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, stream)
|
||||
js.mu.RUnlock()
|
||||
|
||||
if isLeader && sa == nil {
|
||||
@@ -3581,7 +3347,7 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje
|
||||
}
|
||||
|
||||
js.mu.RLock()
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, stream)
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, stream)
|
||||
js.mu.RUnlock()
|
||||
|
||||
if isLeader && sa == nil {
|
||||
@@ -3876,7 +3642,7 @@ func (s *Server) jsStreamPurgeRequest(sub *subscription, c *client, _ *Account,
|
||||
}
|
||||
|
||||
js.mu.RLock()
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, stream)
|
||||
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, stream)
|
||||
js.mu.RUnlock()
|
||||
|
||||
if isLeader && sa == nil {
|
||||
@@ -4048,6 +3814,13 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account
|
||||
return
|
||||
}
|
||||
|
||||
// Check for path like separators in the name.
|
||||
if strings.ContainsAny(stream, `\/`) {
|
||||
resp.Error = NewJSStreamNameContainsPathSeparatorsError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
|
||||
if s.JetStreamIsClustered() {
|
||||
s.jsClusteredStreamRestoreRequest(ci, acc, &req, subject, reply, rmsg)
|
||||
return
|
||||
@@ -4597,11 +4370,49 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
|
||||
isClustered := s.JetStreamIsClustered()
|
||||
|
||||
// Determine if we should proceed here when we are in clustered mode.
|
||||
direct := req.Config.Direct
|
||||
if isClustered {
|
||||
if req.Config.Direct {
|
||||
// Check to see if we have this stream and are the stream leader.
|
||||
if !acc.JetStreamIsStreamLeader(streamNameFromSubject(subject)) {
|
||||
return
|
||||
if direct {
|
||||
// If it's just a direct consumer, check for stream leader.
|
||||
if !req.Config.Sourcing {
|
||||
// Check to see if we have this stream and are the stream leader.
|
||||
if !acc.JetStreamIsStreamLeader(streamNameFromSubject(subject)) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Otherwise, we either need this to be answered by the stream or meta leader.
|
||||
var cc *jetStreamCluster
|
||||
js, cc = s.getJetStreamCluster()
|
||||
if js == nil || cc == nil {
|
||||
return
|
||||
}
|
||||
js.mu.RLock()
|
||||
sa := js.streamAssignmentOrInflight(acc.Name, streamNameFromSubject(subject))
|
||||
if sa == nil {
|
||||
js.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
// If the stream is WQ or Interest, we need the meta leader to answer.
|
||||
if sa.Config.Retention != LimitsPolicy {
|
||||
direct = false
|
||||
}
|
||||
js.mu.RUnlock()
|
||||
if direct {
|
||||
// Check to see if we have this stream and are the stream leader.
|
||||
if !acc.JetStreamIsStreamLeader(streamNameFromSubject(subject)) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if js.isLeaderless() {
|
||||
resp.Error = NewJSClusterNotAvailError()
|
||||
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
|
||||
return
|
||||
}
|
||||
// Make sure we are meta leader.
|
||||
if !s.JetStreamIsLeader() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var cc *jetStreamCluster
|
||||
@@ -4645,6 +4456,7 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
|
||||
// Legacy ephemeral.
|
||||
rt = ccLegacyEphemeral
|
||||
streamName = streamNameFromSubject(subject)
|
||||
consumerName = req.Config.Name
|
||||
} else {
|
||||
// New style and durable legacy.
|
||||
if tokenAt(subject, 4) == "DURABLE" {
|
||||
@@ -4736,7 +4548,7 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
|
||||
return
|
||||
}
|
||||
|
||||
if isClustered && !req.Config.Direct {
|
||||
if isClustered && !direct {
|
||||
s.jsClusteredConsumerRequest(ci, acc, subject, reply, rmsg, req.Stream, &req.Config, req.Action, req.Pedantic)
|
||||
return
|
||||
}
|
||||
@@ -4760,6 +4572,23 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
|
||||
return
|
||||
}
|
||||
|
||||
// If the consumer is a direct sourcing consumer, we need to "upgrade"
|
||||
// it to be durable without AckNone if not a Limits-based stream.
|
||||
if req.Config.Direct && req.Config.Sourcing && req.Config.Name != _EMPTY_ {
|
||||
if !isClustered && stream.isInterestRetention() {
|
||||
req.Config.Direct = false
|
||||
req.Config.Durable = req.Config.Name
|
||||
req.Config.AckPolicy = AckFlowControl
|
||||
req.Config.AckWait = 0
|
||||
req.Config.MaxDeliver = 0
|
||||
req.Config.InactiveThreshold = 0
|
||||
} else {
|
||||
// Otherwise, need to append a randomized suffix since the source uses a stable name.
|
||||
req.Config.Name = fmt.Sprintf("%s-%s", req.Config.Name, createConsumerName())
|
||||
consumerName = req.Config.Name
|
||||
}
|
||||
}
|
||||
|
||||
if o := stream.lookupConsumer(consumerName); o != nil {
|
||||
if o.offlineReason != _EMPTY_ {
|
||||
resp.Error = NewJSConsumerOfflineReasonError(errors.New(o.offlineReason))
|
||||
@@ -4770,6 +4599,12 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
|
||||
// it back to whatever the current configured value is.
|
||||
o.mu.RLock()
|
||||
req.Config.PauseUntil = o.cfg.PauseUntil
|
||||
// If a durable sourcing consumer is used, we need to reset the deliver policy.
|
||||
if req.Config.Sourcing && req.Config.Durable != _EMPTY_ {
|
||||
req.Config.DeliverPolicy = o.cfg.DeliverPolicy
|
||||
req.Config.OptStartSeq = o.cfg.OptStartSeq
|
||||
req.Config.OptStartTime = o.cfg.OptStartTime
|
||||
}
|
||||
o.mu.RUnlock()
|
||||
}
|
||||
|
||||
@@ -5079,7 +4914,7 @@ func (s *Server) jsConsumerInfoRequest(sub *subscription, c *client, _ *Account,
|
||||
groupCreated := meta.Created()
|
||||
|
||||
js.mu.RLock()
|
||||
isLeader, sa, ca := cc.isLeader(), js.streamAssignment(acc.Name, streamName), js.consumerAssignment(acc.Name, streamName, consumerName)
|
||||
isLeader, sa, ca := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, streamName), js.consumerAssignmentOrInflight(acc.Name, streamName, consumerName)
|
||||
var rg *raftGroup
|
||||
var offline, isMember bool
|
||||
if ca != nil {
|
||||
@@ -5404,8 +5239,11 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
|
||||
return
|
||||
}
|
||||
|
||||
nca := *ca
|
||||
nca := ca.clone()
|
||||
// We're only holding the read lock and release below,
|
||||
// we need a copy to prevent concurrent reads/writes.
|
||||
ncfg := *ca.Config
|
||||
ncfg.Metadata = maps.Clone(ncfg.Metadata)
|
||||
nca.Config = &ncfg
|
||||
meta := cc.meta
|
||||
js.mu.RUnlock()
|
||||
@@ -5420,7 +5258,7 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
|
||||
// Only PauseUntil is updated above, so reuse config for both.
|
||||
setStaticConsumerMetadata(nca.Config)
|
||||
|
||||
eca := encodeAddConsumerAssignment(&nca)
|
||||
eca := encodeAddConsumerAssignment(nca)
|
||||
meta.Propose(eca)
|
||||
|
||||
resp.PauseUntil = pauseUTC
|
||||
@@ -5453,7 +5291,13 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
|
||||
return
|
||||
}
|
||||
|
||||
// We're only holding the read lock and release below,
|
||||
// we need a copy to prevent concurrent reads/writes.
|
||||
obs.mu.RLock()
|
||||
ncfg := obs.cfg
|
||||
ncfg.Metadata = maps.Clone(ncfg.Metadata)
|
||||
obs.mu.RUnlock()
|
||||
|
||||
pauseUTC := req.PauseUntil.UTC()
|
||||
if !pauseUTC.IsZero() {
|
||||
ncfg.PauseUntil = &pauseUTC
|
||||
|
||||
+455
-48
@@ -21,6 +21,7 @@ import (
|
||||
"math/big"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -29,60 +30,105 @@ import (
|
||||
|
||||
var (
|
||||
// Tracks the total inflight batches, across all streams and accounts that enable batching.
|
||||
globalInflightBatches atomic.Int32
|
||||
globalInflightAtomicBatches atomic.Int64
|
||||
globalInflightFastBatches atomic.Int64
|
||||
)
|
||||
|
||||
type batching struct {
|
||||
mu sync.Mutex
|
||||
group map[string]*batchGroup
|
||||
mu sync.Mutex
|
||||
atomic map[string]*atomicBatch
|
||||
fast map[string]*fastBatch
|
||||
}
|
||||
|
||||
type batchGroup struct {
|
||||
lseq uint64
|
||||
store StreamStore
|
||||
timer *time.Timer
|
||||
type atomicBatch struct {
|
||||
timer *time.Timer // Inactivity timer for the batch.
|
||||
lseq uint64 // The highest sequence for this batch.
|
||||
store StreamStore // Where the batch is staged before committing.
|
||||
}
|
||||
|
||||
type fastBatch struct {
|
||||
timer *time.Timer // Inactivity timer for the batch.
|
||||
lseq uint64 // The highest sequence for this batch.
|
||||
sseq uint64 // Last persisted stream sequence.
|
||||
pseq uint64 // Last persisted batch sequence (is always lower or equal to lseq).
|
||||
fseq uint64 // Sequence of when we last sent a flow message (is always lower or equal to pseq).
|
||||
pending uint32 // Number of pending messages in the batch waiting to be persisted.
|
||||
ackMessages uint16 // Ack will be sent every N messages.
|
||||
maxAckMessages uint16 // Maximum ackMessages value the client allows.
|
||||
reply string // The last reply subject seen when persisting a message.
|
||||
gapOk bool // Whether a gap is okay, if not, the batch would be rejected.
|
||||
commit bool // If the batch is committed.
|
||||
}
|
||||
|
||||
// newAtomicBatch creates an atomic batch publish object.
|
||||
// Lock should be held.
|
||||
func (batches *batching) newBatchGroup(mset *stream, batchId string) (*batchGroup, error) {
|
||||
store, err := newBatchStore(mset, batchId)
|
||||
func (batches *batching) newAtomicBatch(mset *stream, batchId string, replicas int, storage StorageType, storeDir, streamName string) (*atomicBatch, error) {
|
||||
store, err := newBatchStore(mset, batchId, replicas, storage, storeDir, streamName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := &batchGroup{store: store}
|
||||
b := &atomicBatch{store: store}
|
||||
b.setupCleanupTimer(mset, batchId, batches)
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// setupCleanupTimer sets up a timer to clean up the batch after a timeout.
|
||||
func (b *atomicBatch) setupCleanupTimer(mset *stream, batchId string, batches *batching) {
|
||||
// Create a timer to clean up after timeout.
|
||||
timeout := streamMaxBatchTimeout
|
||||
if maxBatchTimeout := mset.srv.getOpts().JetStreamLimits.MaxBatchTimeout; maxBatchTimeout > 0 {
|
||||
timeout = maxBatchTimeout
|
||||
}
|
||||
timeout := getCleanupTimeout(mset)
|
||||
b.timer = time.AfterFunc(timeout, func() {
|
||||
b.cleanup(batchId, batches)
|
||||
mset.sendStreamBatchAbandonedAdvisory(batchId, BatchTimeout)
|
||||
})
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func getBatchStoreDir(mset *stream, batchId string) (string, string) {
|
||||
mset.mu.RLock()
|
||||
jsa, name := mset.jsa, mset.cfg.Name
|
||||
mset.mu.RUnlock()
|
||||
// resetCleanupTimer resets the cleanup timer, allowing to extend the lifetime of the batch.
|
||||
// Returns whether the timer was reset without it having expired before.
|
||||
func (b *atomicBatch) resetCleanupTimer(mset *stream) bool {
|
||||
timeout := getCleanupTimeout(mset)
|
||||
return b.timer.Reset(timeout)
|
||||
}
|
||||
|
||||
jsa.mu.RLock()
|
||||
sd := jsa.storeDir
|
||||
jsa.mu.RUnlock()
|
||||
// cleanup deletes underlying resources associated with the batch and unregisters it from the stream's batches.
|
||||
func (b *atomicBatch) cleanup(batchId string, batches *batching) {
|
||||
batches.mu.Lock()
|
||||
defer batches.mu.Unlock()
|
||||
b.cleanupLocked(batchId, batches)
|
||||
}
|
||||
|
||||
// Lock should be held.
|
||||
func (b *atomicBatch) cleanupLocked(batchId string, batches *batching) {
|
||||
if b.timer == nil {
|
||||
return
|
||||
}
|
||||
globalInflightAtomicBatches.Add(-1)
|
||||
b.timer.Stop()
|
||||
b.store.Delete(true)
|
||||
delete(batches.atomic, batchId)
|
||||
// Reset so that another invocation doesn't double-account.
|
||||
b.timer = nil
|
||||
}
|
||||
|
||||
// Lock should be held.
|
||||
func (b *atomicBatch) stopLocked() {
|
||||
if b.timer == nil {
|
||||
return
|
||||
}
|
||||
globalInflightAtomicBatches.Add(-1)
|
||||
b.timer.Stop()
|
||||
b.store.Stop()
|
||||
// Reset so that another invocation doesn't double-account.
|
||||
b.timer = nil
|
||||
}
|
||||
|
||||
func getBatchStoreDir(storeDir, streamName, batchId string) (string, string) {
|
||||
bname := getHash(batchId)
|
||||
return bname, filepath.Join(sd, streamsDir, name, batchesDir, bname)
|
||||
return bname, filepath.Join(storeDir, streamsDir, streamName, batchesDir, bname)
|
||||
}
|
||||
|
||||
func newBatchStore(mset *stream, batchId string) (StreamStore, error) {
|
||||
mset.mu.RLock()
|
||||
replicas, storage := mset.cfg.Replicas, mset.cfg.Storage
|
||||
mset.mu.RUnlock()
|
||||
|
||||
func newBatchStore(mset *stream, batchId string, replicas int, storage StorageType, storeDir, streamName string) (StreamStore, error) {
|
||||
if replicas == 1 && storage == FileStorage {
|
||||
bname, storeDir := getBatchStoreDir(mset, batchId)
|
||||
bname, storeDir := getBatchStoreDir(storeDir, streamName, batchId)
|
||||
fcfg := FileStoreConfig{AsyncFlush: true, BlockSize: defaultLargeBlockSize, StoreDir: storeDir}
|
||||
s := mset.srv
|
||||
prf := s.jsKeyGen(s.getOpts().JetStreamKey, mset.acc.Name)
|
||||
@@ -101,34 +147,264 @@ func newBatchStore(mset *stream, batchId string) (StreamStore, error) {
|
||||
// If the timer has already cleaned up the batch, we can't commit.
|
||||
// Otherwise, we ensure the timer does not clean up the batch in the meantime.
|
||||
// Lock should be held.
|
||||
func (b *batchGroup) readyForCommit() bool {
|
||||
func (b *atomicBatch) readyForCommit() *BatchAbandonReason {
|
||||
if !b.timer.Stop() {
|
||||
return &BatchTimeout
|
||||
}
|
||||
if b.store.FlushAllPending() != nil {
|
||||
return &BatchIncomplete
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newFastBatch creates a fast batch publish object and registers it in batches.fast.
|
||||
// Lock should be held.
|
||||
func (batches *batching) newFastBatch(mset *stream, batchId string, gapOk bool, maxAckMessages uint16) *fastBatch {
|
||||
b := &fastBatch{gapOk: gapOk, maxAckMessages: maxAckMessages}
|
||||
if batches.fast == nil {
|
||||
batches.fast = make(map[string]*fastBatch, 1)
|
||||
}
|
||||
batches.fast[batchId] = b
|
||||
batches.fastBatchInit(b)
|
||||
b.setupCleanupTimer(mset, batchId, batches)
|
||||
return b
|
||||
}
|
||||
|
||||
// fastBatchInit (re)initializes the ackMessages field for a fast batch.
|
||||
// The batch must already be registered in batches.fast.
|
||||
// Lock should be held.
|
||||
func (batches *batching) fastBatchInit(b *fastBatch) {
|
||||
// If it's the only batch, just allow what the client wants, otherwise we'll
|
||||
// need to coordinate and slowly ramp up this publisher.
|
||||
// TODO(mvv): fast ingest's initial flow value improvements?
|
||||
ackMessages := min(500, b.maxAckMessages)
|
||||
if len(batches.fast) > 1 {
|
||||
ackMessages = 1
|
||||
}
|
||||
b.ackMessages = ackMessages
|
||||
}
|
||||
|
||||
// fastBatchReset resets the fast batch to an empty state and sends a flow control message.
|
||||
// Lock should be held.
|
||||
func (batches *batching) fastBatchReset(mset *stream, batchId string, b *fastBatch) {
|
||||
// If the timer already stopped before we could commit, we clean it up.
|
||||
if b.timer == nil || (!b.commit && !b.timer.Stop()) {
|
||||
b.cleanupLocked(batchId, batches)
|
||||
return
|
||||
}
|
||||
// Otherwise, reset the state.
|
||||
batches.fastBatchInit(b)
|
||||
b.timer.Reset(getCleanupTimeout(mset))
|
||||
b.commit = false
|
||||
b.pending = 0
|
||||
b.fseq, b.lseq = b.pseq, b.pseq
|
||||
b.sendFlowControl(b.fseq, mset, b.reply)
|
||||
}
|
||||
|
||||
// fastBatchRegisterSequences registers the highest stored batch and stream sequence and returns
|
||||
// whether a PubAck should be sent if the batch has been committed.
|
||||
// If this is called on a follower, it only registers the highest stream and persisted batch sequences.
|
||||
// Lock should be held.
|
||||
func (batches *batching) fastBatchRegisterSequences(mset *stream, reply string, streamSeq uint64, isLeader bool, batch *FastBatch) bool {
|
||||
b, ok := batches.fast[batch.id]
|
||||
if !ok || !isLeader {
|
||||
// If this batch has committed, we can clean it up.
|
||||
if batch.commit {
|
||||
if b != nil {
|
||||
b.cleanupLocked(batch.id, batches)
|
||||
}
|
||||
return false
|
||||
}
|
||||
// Otherwise, even as a follower, we record the latest state of this batch.
|
||||
if b == nil || !b.resetCleanupTimer(mset) {
|
||||
if b != nil {
|
||||
// The timer couldn't be reset, this means the timer already runs and is likely
|
||||
// waiting to acquire the lock. We reset the timer here so it doesn't clean up
|
||||
// this batch that we're about to overwrite.
|
||||
b.timer = nil
|
||||
} else {
|
||||
// If this is a new batch for us, even though we're a follower, we still need
|
||||
// to account toward the global inflight limit.
|
||||
globalInflightFastBatches.Add(1)
|
||||
}
|
||||
// We'll need a copy as we'll use it as a key and later for cleanup.
|
||||
batchId := copyString(batch.id)
|
||||
b = batches.newFastBatch(mset, batchId, batch.gapOk, batch.flow)
|
||||
}
|
||||
b.sseq = streamSeq
|
||||
b.pseq, b.lseq = batch.seq, batch.seq
|
||||
b.reply = reply
|
||||
return false
|
||||
}
|
||||
b.store.FlushAllPending()
|
||||
b.reply = reply
|
||||
if b.pending > 0 {
|
||||
b.pending--
|
||||
}
|
||||
b.sseq = streamSeq
|
||||
// Store last persisted batch sequence.
|
||||
// If we have no remaining pending writes, we might have had duplicate messages
|
||||
// and need to send additional flow control messages.
|
||||
var skipped bool
|
||||
if b.pending == 0 {
|
||||
skipped = true
|
||||
b.pseq = b.lseq
|
||||
} else {
|
||||
b.pseq = batch.seq
|
||||
}
|
||||
// If the PubAck needs to be sent now as a result of a commit.
|
||||
if b.lseq == b.pseq && b.commit {
|
||||
b.cleanupLocked(batch.id, batches)
|
||||
// If we skipped ahead due to duplicate messages, send the PubAck with the highest sequence.
|
||||
if skipped {
|
||||
var buf [256]byte
|
||||
pubAck := append(buf[:0], mset.pubAck...)
|
||||
response := append(pubAck, strconv.FormatUint(b.sseq, 10)...)
|
||||
response = append(response, fmt.Sprintf(",\"batch\":%q,\"count\":%d}", batch.id, b.lseq)...)
|
||||
if len(reply) > 0 {
|
||||
mset.outq.sendMsg(reply, response)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
b.checkFlowControl(mset, reply, batches)
|
||||
return false
|
||||
}
|
||||
|
||||
// checkFlowControl checks whether a flow control message should be sent.
|
||||
// If so, it updates the flow values to speed up or slow down the publisher if needed.
|
||||
// Returns whether a flow control message was sent.
|
||||
// Lock should be held.
|
||||
func (b *fastBatch) checkFlowControl(mset *stream, reply string, batches *batching) bool {
|
||||
am := uint64(b.ackMessages)
|
||||
if b.pseq < b.fseq+am {
|
||||
return false
|
||||
}
|
||||
// Instead of sending multiple flow control messages, skip ahead to only send the last.
|
||||
steps := (b.pseq - b.fseq) / am
|
||||
b.fseq += steps * am
|
||||
|
||||
// TODO(mvv): fast ingest's dynamic flow value improvements?
|
||||
// This is currently just a simple value to have a working version. Should take average
|
||||
// message sizes into account and compare how much this client is contributing to the
|
||||
// ingest IPQ total size and messages and have publishers share based on that.
|
||||
maxAckMessages := uint16(500 / len(batches.fast))
|
||||
if maxAckMessages < 1 {
|
||||
maxAckMessages = 1
|
||||
}
|
||||
// Limit to the client's allowed maximum.
|
||||
if maxAckMessages > b.maxAckMessages {
|
||||
maxAckMessages = b.maxAckMessages
|
||||
}
|
||||
|
||||
if b.ackMessages < maxAckMessages {
|
||||
// Ramp up.
|
||||
b.ackMessages *= 2
|
||||
if b.ackMessages > maxAckMessages {
|
||||
b.ackMessages = maxAckMessages
|
||||
}
|
||||
} else if b.ackMessages > maxAckMessages {
|
||||
// Slow down.
|
||||
b.ackMessages /= 2
|
||||
if b.ackMessages <= maxAckMessages {
|
||||
b.ackMessages = maxAckMessages
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, send the flow control message.
|
||||
b.sendFlowControl(b.fseq, mset, reply)
|
||||
return true
|
||||
}
|
||||
|
||||
// sendFlowControl sends a fast batch flow control message for the current highest sequence.
|
||||
// Lock should be held.
|
||||
func (b *fastBatch) sendFlowControl(batchSeq uint64, mset *stream, reply string) {
|
||||
if len(reply) == 0 {
|
||||
return
|
||||
}
|
||||
response, _ := BatchFlowAck{Sequence: batchSeq, Messages: b.ackMessages}.MarshalJSON()
|
||||
mset.outq.sendMsg(reply, response)
|
||||
}
|
||||
|
||||
// fastBatchCommit ends the batch and commits the data up to that point. If all messages
|
||||
// have already been persisted, a PubAck is sent immediately. Otherwise, it will be sent
|
||||
// after the last message has been persisted.
|
||||
// Lock should be held.
|
||||
func (batches *batching) fastBatchCommit(b *fastBatch, batchId string, mset *stream, reply string) bool {
|
||||
// Either we commit now, or we clean up later, so stop the timer.
|
||||
if b.timer == nil || (!b.commit && !b.timer.Stop()) {
|
||||
// Shouldn't be possible for the timer to already be stopped if we haven't committed yet,
|
||||
// since we pre-check being able to reset the timer. But guard against it anyhow.
|
||||
return true
|
||||
}
|
||||
// Mark that this batch commits.
|
||||
b.commit = true
|
||||
// If the whole batch has been persisted, we can respond with the PubAck now.
|
||||
if b.lseq == b.pseq {
|
||||
b.cleanupLocked(batchId, batches)
|
||||
var buf [256]byte
|
||||
pubAck := append(buf[:0], mset.pubAck...)
|
||||
response := append(pubAck, strconv.FormatUint(b.sseq, 10)...)
|
||||
response = append(response, fmt.Sprintf(",\"batch\":%q,\"count\":%d}", batchId, b.lseq)...)
|
||||
if len(reply) > 0 {
|
||||
mset.outq.sendMsg(reply, response)
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Otherwise, we need to wait and the PubAck will be sent when the last message is persisted.
|
||||
return false
|
||||
}
|
||||
|
||||
// setupCleanupTimer sets up a timer to clean up the batch after a timeout.
|
||||
func (b *fastBatch) setupCleanupTimer(mset *stream, batchId string, batches *batching) {
|
||||
// Create a timer to clean up after timeout.
|
||||
timeout := getCleanupTimeout(mset)
|
||||
b.timer = time.AfterFunc(timeout, func() {
|
||||
b.cleanup(batchId, batches)
|
||||
})
|
||||
}
|
||||
|
||||
// resetCleanupTimer resets the cleanup timer, allowing to extend the lifetime of the batch.
|
||||
// Returns whether the timer was reset without it having expired before.
|
||||
func (b *fastBatch) resetCleanupTimer(mset *stream) bool {
|
||||
if b.commit {
|
||||
return true
|
||||
}
|
||||
if b.timer == nil {
|
||||
return false
|
||||
}
|
||||
timeout := getCleanupTimeout(mset)
|
||||
return b.timer.Reset(timeout)
|
||||
}
|
||||
|
||||
// cleanup deletes underlying resources associated with the batch and unregisters it from the stream's batches.
|
||||
func (b *batchGroup) cleanup(batchId string, batches *batching) {
|
||||
func (b *fastBatch) cleanup(batchId string, batches *batching) {
|
||||
batches.mu.Lock()
|
||||
defer batches.mu.Unlock()
|
||||
b.cleanupLocked(batchId, batches)
|
||||
}
|
||||
|
||||
// Lock should be held.
|
||||
func (b *batchGroup) cleanupLocked(batchId string, batches *batching) {
|
||||
globalInflightBatches.Add(-1)
|
||||
func (b *fastBatch) cleanupLocked(batchId string, batches *batching) {
|
||||
// If the timer is nil, it means this batch has been replaced with a new one.
|
||||
// This can happen on a follower depending on timing.
|
||||
if b.timer == nil {
|
||||
return
|
||||
}
|
||||
globalInflightFastBatches.Add(-1)
|
||||
b.timer.Stop()
|
||||
b.store.Delete(true)
|
||||
delete(batches.group, batchId)
|
||||
delete(batches.fast, batchId)
|
||||
// Reset so that another invocation doesn't double-account.
|
||||
b.timer = nil
|
||||
}
|
||||
|
||||
// Lock should be held.
|
||||
func (b *batchGroup) stopLocked() {
|
||||
globalInflightBatches.Add(-1)
|
||||
b.timer.Stop()
|
||||
b.store.Stop()
|
||||
// getCleanupTimeout returns the timeout for the batch, taking into account the server's limits.
|
||||
func getCleanupTimeout(mset *stream) time.Duration {
|
||||
timeout := streamMaxBatchTimeout
|
||||
if maxBatchTimeout := mset.srv.getOpts().JetStreamLimits.MaxBatchTimeout; maxBatchTimeout > 0 {
|
||||
timeout = maxBatchTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
// batchStagedDiff stages all changes for consistency checks until commit.
|
||||
@@ -136,6 +412,7 @@ type batchStagedDiff struct {
|
||||
msgIds map[string]struct{}
|
||||
counter map[string]*msgCounterRunningTotal
|
||||
inflight map[string]*inflightSubjectRunningTotal
|
||||
inflightTransform map[uint64]string
|
||||
expectedPerSubject map[string]*batchExpectedPerSubject
|
||||
}
|
||||
|
||||
@@ -180,6 +457,16 @@ func (diff *batchStagedDiff) commit(mset *stream) {
|
||||
}
|
||||
}
|
||||
|
||||
// Track inflight subject transforms.
|
||||
if len(diff.inflightTransform) > 0 {
|
||||
if mset.inflightTransform == nil {
|
||||
mset.inflightTransform = make(map[uint64]string, len(diff.inflightTransform))
|
||||
}
|
||||
for clseq, subj := range diff.inflightTransform {
|
||||
mset.inflightTransform[clseq] = subj
|
||||
}
|
||||
}
|
||||
|
||||
// Track sequence and subject.
|
||||
if len(diff.expectedPerSubject) > 0 {
|
||||
if mset.expectedPerSubjectSequence == nil {
|
||||
@@ -238,7 +525,7 @@ func (batch *batchApply) rejectBatchState(mset *stream) {
|
||||
// mset.mu lock must NOT be held or used.
|
||||
// mset.clMu lock must be held.
|
||||
func checkMsgHeadersPreClusteredProposal(
|
||||
diff *batchStagedDiff, mset *stream, subject string, hdr []byte, msg []byte, sourced bool, name string,
|
||||
diff *batchStagedDiff, mset *stream, subject, rsubject string, hdr []byte, msg []byte, sourced bool, name string,
|
||||
jsa *jsAccount, allowRollup, denyPurge, allowTTL, allowMsgCounter, allowMsgSchedules bool,
|
||||
discard DiscardPolicy, discardNewPer bool, maxMsgSize int, maxMsgs int64, maxMsgsPer int64, maxBytes int64,
|
||||
) ([]byte, []byte, uint64, *ApiError, error) {
|
||||
@@ -515,8 +802,9 @@ func checkMsgHeadersPreClusteredProposal(
|
||||
}
|
||||
|
||||
// Message scheduling.
|
||||
if schedule, ok := getMessageSchedule(hdr); !ok {
|
||||
apiErr := NewJSMessageSchedulesPatternInvalidError()
|
||||
if sourced {
|
||||
// noop, sourced messages were already validated by the origin stream.
|
||||
} else if schedule, apiErr := getMessageSchedule(hdr); apiErr != nil {
|
||||
if !allowMsgSchedules {
|
||||
apiErr = NewJSMessageSchedulesDisabledError()
|
||||
}
|
||||
@@ -528,22 +816,40 @@ func checkMsgHeadersPreClusteredProposal(
|
||||
} else if scheduleTtl, ok := getMessageScheduleTTL(hdr); !ok {
|
||||
apiErr := NewJSMessageSchedulesTTLInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
} else if scheduleRollup := getMessageScheduleRollup(hdr); scheduleRollup != _EMPTY_ && scheduleRollup != JSMsgRollupSubject {
|
||||
apiErr := NewJSMessageSchedulesRollupInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
} else if scheduleTtl != _EMPTY_ && !allowTTL {
|
||||
return hdr, msg, 0, NewJSMessageTTLDisabledError(), errMsgTTLDisabled
|
||||
} else if scheduleTarget := getMessageScheduleTarget(hdr); scheduleTarget == _EMPTY_ ||
|
||||
!IsValidPublishSubject(scheduleTarget) || SubjectsCollide(scheduleTarget, subject) {
|
||||
!IsValidPublishSubject(scheduleTarget) || scheduleTarget == subject {
|
||||
apiErr := NewJSMessageSchedulesTargetInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
} else if scheduleSource := getMessageScheduleSource(hdr); scheduleSource != _EMPTY_ &&
|
||||
(scheduleSource == scheduleTarget || scheduleSource == subject || !IsValidPublishSubject(scheduleSource)) {
|
||||
apiErr := NewJSMessageSchedulesSourceInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
} else {
|
||||
mset.cfgMu.RLock()
|
||||
match := slices.ContainsFunc(mset.cfg.Subjects, func(subj string) bool {
|
||||
return SubjectsCollide(subj, scheduleTarget)
|
||||
})
|
||||
mset.cfgMu.RUnlock()
|
||||
if !match {
|
||||
mset.cfgMu.RUnlock()
|
||||
apiErr := NewJSMessageSchedulesTargetInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
}
|
||||
if scheduleSource != _EMPTY_ {
|
||||
match = slices.ContainsFunc(mset.cfg.Subjects, func(subj string) bool {
|
||||
return SubjectsCollide(subj, scheduleSource)
|
||||
})
|
||||
if !match {
|
||||
mset.cfgMu.RUnlock()
|
||||
apiErr := NewJSMessageSchedulesSourceInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
}
|
||||
}
|
||||
mset.cfgMu.RUnlock()
|
||||
|
||||
// Add a rollup sub header if it doesn't already exist.
|
||||
// Otherwise, it must exist already as a rollup on the subject.
|
||||
@@ -555,10 +861,32 @@ func checkMsgHeadersPreClusteredProposal(
|
||||
}
|
||||
}
|
||||
}
|
||||
if scheduleNext := sliceHeader(JSScheduleNext, hdr); len(scheduleNext) > 0 && !sourced {
|
||||
// Clients may only use Nats-Schedule-Next to purge a schedule.
|
||||
if bytesToString(scheduleNext) != JSScheduleNextPurge {
|
||||
apiErr := NewJSMessageSchedulesSchedulerInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
}
|
||||
// Nats-Scheduler must accompany the purge and:
|
||||
// - it must NOT be empty.
|
||||
// - it must NOT match the publish subject.
|
||||
if scheduler := sliceHeader(JSScheduler, hdr); len(scheduler) == 0 ||
|
||||
bytesToString(scheduler) == subject || !IsValidPublishSubject(bytesToString(scheduler)) {
|
||||
apiErr := NewJSMessageSchedulesSchedulerInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
} else if !allowMsgSchedules {
|
||||
apiErr := NewJSMessageSchedulesDisabledError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
}
|
||||
} else if !sourced && len(sliceHeader(JSScheduler, hdr)) > 0 {
|
||||
// Clients may only use Nats-Scheduler alongside Nats-Schedule-Next.
|
||||
apiErr := NewJSMessageSchedulesSchedulerInvalidError()
|
||||
return hdr, msg, 0, apiErr, apiErr
|
||||
}
|
||||
|
||||
// Check for any rollups.
|
||||
if rollup := getRollup(hdr); rollup != _EMPTY_ {
|
||||
if !allowRollup || denyPurge {
|
||||
if (!allowRollup || denyPurge) && !sourced {
|
||||
err := errors.New("rollup not permitted")
|
||||
return hdr, msg, 0, NewJSStreamRollupFailedError(err), err
|
||||
}
|
||||
@@ -607,6 +935,19 @@ func checkMsgHeadersPreClusteredProposal(
|
||||
diff.inflight[subject] = i
|
||||
}
|
||||
|
||||
// Subject transform.
|
||||
if subject != rsubject {
|
||||
// The 'subject' is a transformed subject used for consistency checks.
|
||||
// But since we propose the original (raw) subject to our peers, we need
|
||||
// to store the transformed subject separately for when we apply.
|
||||
// TODO(mvv): since subject transforms are handled by each replica individually, this has a
|
||||
// potential for desync given out-of-order stream subject transform updates.
|
||||
if diff.inflightTransform == nil {
|
||||
diff.inflightTransform = make(map[uint64]string, 1)
|
||||
}
|
||||
diff.inflightTransform[mset.clseq] = subject
|
||||
}
|
||||
|
||||
// Check if we have discard new with max msgs or bytes.
|
||||
// We need to deny here otherwise we'd need to bump CLFS, and it could succeed on some
|
||||
// peers and not others depending on consumer ack state (if interest policy).
|
||||
@@ -639,7 +980,8 @@ func checkMsgHeadersPreClusteredProposal(
|
||||
}
|
||||
|
||||
// Similarly, check DiscardNew per-subject threshold to not need to bump CLFS.
|
||||
if discardNewPer && maxMsgsPer > 0 {
|
||||
// Allow rollup messages through since they will purge after storing.
|
||||
if discardNewPer && maxMsgsPer > 0 && len(sliceHeader(JSMsgRollup, hdr)) == 0 {
|
||||
// Get the current total for this subject.
|
||||
totalMsgsForSubject := mset.store.SubjectsTotals(subject)[subject]
|
||||
// Add inflight count in this batch and for this stream.
|
||||
@@ -656,3 +998,68 @@ func checkMsgHeadersPreClusteredProposal(
|
||||
|
||||
return hdr, msg, 0, nil, nil
|
||||
}
|
||||
|
||||
// recalculateClusteredSeq initializes or updates mset.clseq, for example after a leader change.
|
||||
// This is reused for normal clustered publishing into a stream, and for atomic and fast batch publishing.
|
||||
// mset.clMu lock must be held.
|
||||
func recalculateClusteredSeq(mset *stream, needStreamLock bool) (lseq uint64) {
|
||||
// Need to unlock and re-acquire the locks in the proper order.
|
||||
mset.clMu.Unlock()
|
||||
// Locking order is stream -> batchMu -> clMu
|
||||
if needStreamLock {
|
||||
mset.mu.RLock()
|
||||
}
|
||||
batch := mset.batchApply
|
||||
var batchCount uint64
|
||||
if batch != nil {
|
||||
batch.mu.Lock()
|
||||
batchCount = batch.count
|
||||
}
|
||||
mset.clMu.Lock()
|
||||
// Re-capture
|
||||
lseq = mset.lseq
|
||||
mset.clseq = lseq + mset.clfs + batchCount
|
||||
// Keep hold of the mset.clMu, but unlock the others.
|
||||
if batch != nil {
|
||||
batch.mu.Unlock()
|
||||
}
|
||||
if needStreamLock {
|
||||
mset.mu.RUnlock()
|
||||
}
|
||||
return lseq
|
||||
}
|
||||
|
||||
// commitSingleMsg commits and proposes a single message to the node.
|
||||
// This is reused both for normal publishing into a stream, and for fast batch publishing.
|
||||
// mset.clMu lock must be held.
|
||||
func commitSingleMsg(
|
||||
diff *batchStagedDiff, mset *stream, subject string, reply string, hdr []byte, msg []byte, name string,
|
||||
jsa *jsAccount, mt *msgTrace, node RaftNode, replicas int, lseq uint64,
|
||||
) error {
|
||||
// Do proposal.
|
||||
esm := encodeStreamMsgAllowCompress(subject, reply, hdr, msg, mset.clseq, time.Now().UnixNano(), false)
|
||||
if err := node.Propose(esm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var mtKey uint64
|
||||
if mt != nil {
|
||||
mtKey = mset.clseq
|
||||
if mset.mt == nil {
|
||||
mset.mt = make(map[uint64]*msgTrace)
|
||||
}
|
||||
mset.mt[mtKey] = mt
|
||||
}
|
||||
|
||||
diff.commit(mset)
|
||||
mset.clseq++
|
||||
mset.trackReplicationTraffic(node, len(esm), replicas)
|
||||
|
||||
// Check to see if we are being overrun.
|
||||
// TODO(dlc) - Make this a limit where we drop messages to protect ourselves, but allow to be configured.
|
||||
if mset.clseq-(lseq+mset.clfs) > streamLagWarnThreshold {
|
||||
lerr := fmt.Errorf("JetStream stream '%s > %s' has high message lag", jsa.acc().Name, name)
|
||||
mset.srv.RateLimitWarnf("%s", lerr.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+640
-295
File diff suppressed because it is too large
Load Diff
+300
@@ -29,12 +29,30 @@ const (
|
||||
// JSAtomicPublishTooLargeBatchErrF atomic publish batch is too large: {size}
|
||||
JSAtomicPublishTooLargeBatchErrF ErrorIdentifier = 10199
|
||||
|
||||
// JSAtomicPublishTooManyInflight atomic publish too many inflight
|
||||
JSAtomicPublishTooManyInflight ErrorIdentifier = 10210
|
||||
|
||||
// JSAtomicPublishUnsupportedHeaderBatchErr atomic publish unsupported header used: {header}
|
||||
JSAtomicPublishUnsupportedHeaderBatchErr ErrorIdentifier = 10177
|
||||
|
||||
// JSBadRequestErr bad request
|
||||
JSBadRequestErr ErrorIdentifier = 10003
|
||||
|
||||
// JSBatchPublishDisabledErr batch publish is disabled
|
||||
JSBatchPublishDisabledErr ErrorIdentifier = 10205
|
||||
|
||||
// JSBatchPublishInvalidBatchIDErr batch publish ID is invalid
|
||||
JSBatchPublishInvalidBatchIDErr ErrorIdentifier = 10207
|
||||
|
||||
// JSBatchPublishInvalidPatternErr batch publish pattern is invalid
|
||||
JSBatchPublishInvalidPatternErr ErrorIdentifier = 10206
|
||||
|
||||
// JSBatchPublishTooManyInflight batch publish too many inflight
|
||||
JSBatchPublishTooManyInflight ErrorIdentifier = 10211
|
||||
|
||||
// JSBatchPublishUnknownBatchIDErr batch publish ID unknown
|
||||
JSBatchPublishUnknownBatchIDErr ErrorIdentifier = 10208
|
||||
|
||||
// JSClusterIncompleteErr incomplete results
|
||||
JSClusterIncompleteErr ErrorIdentifier = 10004
|
||||
|
||||
@@ -71,6 +89,21 @@ const (
|
||||
// JSClusterUnSupportFeatureErr not currently supported in clustered mode
|
||||
JSClusterUnSupportFeatureErr ErrorIdentifier = 10036
|
||||
|
||||
// JSConsumerAckFCRequiresFCErr flow control ack policy requires flow control
|
||||
JSConsumerAckFCRequiresFCErr ErrorIdentifier = 10219
|
||||
|
||||
// JSConsumerAckFCRequiresMaxAckPendingErr flow control ack policy requires max ack pending
|
||||
JSConsumerAckFCRequiresMaxAckPendingErr ErrorIdentifier = 10220
|
||||
|
||||
// JSConsumerAckFCRequiresNoAckWaitErr flow control ack policy requires unset ack wait
|
||||
JSConsumerAckFCRequiresNoAckWaitErr ErrorIdentifier = 10221
|
||||
|
||||
// JSConsumerAckFCRequiresNoMaxDeliverErr flow control ack policy requires unset max deliver
|
||||
JSConsumerAckFCRequiresNoMaxDeliverErr ErrorIdentifier = 10222
|
||||
|
||||
// JSConsumerAckFCRequiresPushErr flow control ack policy requires a push based consumer
|
||||
JSConsumerAckFCRequiresPushErr ErrorIdentifier = 10218
|
||||
|
||||
// JSConsumerAckPolicyInvalidErr consumer ack policy invalid
|
||||
JSConsumerAckPolicyInvalidErr ErrorIdentifier = 10181
|
||||
|
||||
@@ -167,6 +200,9 @@ const (
|
||||
// JSConsumerInvalidPriorityGroupErr Provided priority group does not exist for this consumer
|
||||
JSConsumerInvalidPriorityGroupErr ErrorIdentifier = 10160
|
||||
|
||||
// JSConsumerInvalidResetErr invalid reset: {err}
|
||||
JSConsumerInvalidResetErr ErrorIdentifier = 10204
|
||||
|
||||
// JSConsumerInvalidSamplingErrF failed to parse consumer sampling configuration: {err}
|
||||
JSConsumerInvalidSamplingErrF ErrorIdentifier = 10095
|
||||
|
||||
@@ -317,21 +353,36 @@ const (
|
||||
// JSMessageSchedulesRollupInvalidErr message schedules invalid rollup
|
||||
JSMessageSchedulesRollupInvalidErr ErrorIdentifier = 10192
|
||||
|
||||
// JSMessageSchedulesSchedulerInvalidErr message schedules invalid scheduler
|
||||
JSMessageSchedulesSchedulerInvalidErr ErrorIdentifier = 10212
|
||||
|
||||
// JSMessageSchedulesSourceInvalidErr message schedules source is invalid
|
||||
JSMessageSchedulesSourceInvalidErr ErrorIdentifier = 10203
|
||||
|
||||
// JSMessageSchedulesTTLInvalidErr message schedules invalid per-message TTL
|
||||
JSMessageSchedulesTTLInvalidErr ErrorIdentifier = 10191
|
||||
|
||||
// JSMessageSchedulesTargetInvalidErr message schedules target is invalid
|
||||
JSMessageSchedulesTargetInvalidErr ErrorIdentifier = 10190
|
||||
|
||||
// JSMessageSchedulesTimeZoneInvalidErr message schedules time zone is invalid
|
||||
JSMessageSchedulesTimeZoneInvalidErr ErrorIdentifier = 10223
|
||||
|
||||
// JSMessageTTLDisabledErr per-message TTL is disabled
|
||||
JSMessageTTLDisabledErr ErrorIdentifier = 10166
|
||||
|
||||
// JSMessageTTLInvalidErr invalid per-message TTL
|
||||
JSMessageTTLInvalidErr ErrorIdentifier = 10165
|
||||
|
||||
// JSMirrorConsumerRequiresAckFCErr stream mirror consumer requires flow control ack policy
|
||||
JSMirrorConsumerRequiresAckFCErr ErrorIdentifier = 10214
|
||||
|
||||
// JSMirrorConsumerSetupFailedErrF generic mirror consumer setup failure string ({err})
|
||||
JSMirrorConsumerSetupFailedErrF ErrorIdentifier = 10029
|
||||
|
||||
// JSMirrorDurableConsumerCfgInvalid stream mirror consumer config is invalid
|
||||
JSMirrorDurableConsumerCfgInvalid ErrorIdentifier = 10213
|
||||
|
||||
// JSMirrorInvalidStreamName mirrored stream name is invalid
|
||||
JSMirrorInvalidStreamName ErrorIdentifier = 10142
|
||||
|
||||
@@ -353,6 +404,9 @@ const (
|
||||
// JSMirrorWithAtomicPublishErr stream mirrors can not also use atomic publishing
|
||||
JSMirrorWithAtomicPublishErr ErrorIdentifier = 10198
|
||||
|
||||
// JSMirrorWithBatchPublishErr stream mirrors can not also use batch publishing
|
||||
JSMirrorWithBatchPublishErr ErrorIdentifier = 10209
|
||||
|
||||
// JSMirrorWithCountersErr stream mirrors can not also calculate counters
|
||||
JSMirrorWithCountersErr ErrorIdentifier = 10173
|
||||
|
||||
@@ -416,12 +470,21 @@ const (
|
||||
// JSSnapshotDeliverSubjectInvalidErr deliver subject not valid
|
||||
JSSnapshotDeliverSubjectInvalidErr ErrorIdentifier = 10015
|
||||
|
||||
// JSSourceConsumerRequiresAckFCErr stream source consumer requires flow control ack policy
|
||||
JSSourceConsumerRequiresAckFCErr ErrorIdentifier = 10217
|
||||
|
||||
// JSSourceConsumerSetupFailedErrF General source consumer setup failure string ({err})
|
||||
JSSourceConsumerSetupFailedErrF ErrorIdentifier = 10045
|
||||
|
||||
// JSSourceDuplicateDetected source stream, filter and transform (plus external if present) must form a unique combination (duplicate source configuration detected)
|
||||
JSSourceDuplicateDetected ErrorIdentifier = 10140
|
||||
|
||||
// JSSourceDurableConsumerCfgInvalid stream source consumer config is invalid
|
||||
JSSourceDurableConsumerCfgInvalid ErrorIdentifier = 10215
|
||||
|
||||
// JSSourceDurableConsumerDuplicateDetected duplicate stream source consumer detected
|
||||
JSSourceDurableConsumerDuplicateDetected ErrorIdentifier = 10216
|
||||
|
||||
// JSSourceInvalidStreamName sourced stream name is invalid
|
||||
JSSourceInvalidStreamName ErrorIdentifier = 10141
|
||||
|
||||
@@ -619,8 +682,14 @@ var (
|
||||
JSAtomicPublishInvalidBatchIDErr: {Code: 400, ErrCode: 10179, Description: "atomic publish batch ID is invalid"},
|
||||
JSAtomicPublishMissingSeqErr: {Code: 400, ErrCode: 10175, Description: "atomic publish sequence is missing"},
|
||||
JSAtomicPublishTooLargeBatchErrF: {Code: 400, ErrCode: 10199, Description: "atomic publish batch is too large: {size}"},
|
||||
JSAtomicPublishTooManyInflight: {Code: 429, ErrCode: 10210, Description: "atomic publish too many inflight"},
|
||||
JSAtomicPublishUnsupportedHeaderBatchErr: {Code: 400, ErrCode: 10177, Description: "atomic publish unsupported header used: {header}"},
|
||||
JSBadRequestErr: {Code: 400, ErrCode: 10003, Description: "bad request"},
|
||||
JSBatchPublishDisabledErr: {Code: 400, ErrCode: 10205, Description: "batch publish is disabled"},
|
||||
JSBatchPublishInvalidBatchIDErr: {Code: 400, ErrCode: 10207, Description: "batch publish ID is invalid"},
|
||||
JSBatchPublishInvalidPatternErr: {Code: 400, ErrCode: 10206, Description: "batch publish pattern is invalid"},
|
||||
JSBatchPublishTooManyInflight: {Code: 429, ErrCode: 10211, Description: "batch publish too many inflight"},
|
||||
JSBatchPublishUnknownBatchIDErr: {Code: 400, ErrCode: 10208, Description: "batch publish ID unknown"},
|
||||
JSClusterIncompleteErr: {Code: 503, ErrCode: 10004, Description: "incomplete results"},
|
||||
JSClusterNoPeersErrF: {Code: 400, ErrCode: 10005, Description: "{err}"},
|
||||
JSClusterNotActiveErr: {Code: 500, ErrCode: 10006, Description: "JetStream not in clustered mode"},
|
||||
@@ -633,6 +702,11 @@ var (
|
||||
JSClusterServerNotMemberErr: {Code: 400, ErrCode: 10044, Description: "server is not a member of the cluster"},
|
||||
JSClusterTagsErr: {Code: 400, ErrCode: 10011, Description: "tags placement not supported for operation"},
|
||||
JSClusterUnSupportFeatureErr: {Code: 503, ErrCode: 10036, Description: "not currently supported in clustered mode"},
|
||||
JSConsumerAckFCRequiresFCErr: {Code: 400, ErrCode: 10219, Description: "flow control ack policy requires flow control"},
|
||||
JSConsumerAckFCRequiresMaxAckPendingErr: {Code: 400, ErrCode: 10220, Description: "flow control ack policy requires max ack pending"},
|
||||
JSConsumerAckFCRequiresNoAckWaitErr: {Code: 400, ErrCode: 10221, Description: "flow control ack policy requires unset ack wait"},
|
||||
JSConsumerAckFCRequiresNoMaxDeliverErr: {Code: 400, ErrCode: 10222, Description: "flow control ack policy requires unset max deliver"},
|
||||
JSConsumerAckFCRequiresPushErr: {Code: 400, ErrCode: 10218, Description: "flow control ack policy requires a push based consumer"},
|
||||
JSConsumerAckPolicyInvalidErr: {Code: 400, ErrCode: 10181, Description: "consumer ack policy invalid"},
|
||||
JSConsumerAckWaitNegativeErr: {Code: 400, ErrCode: 10183, Description: "consumer ack wait needs to be positive"},
|
||||
JSConsumerAlreadyExists: {Code: 400, ErrCode: 10148, Description: "consumer already exists"},
|
||||
@@ -665,6 +739,7 @@ var (
|
||||
JSConsumerInvalidGroupNameErr: {Code: 400, ErrCode: 10162, Description: "Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters"},
|
||||
JSConsumerInvalidPolicyErrF: {Code: 400, ErrCode: 10094, Description: "{err}"},
|
||||
JSConsumerInvalidPriorityGroupErr: {Code: 400, ErrCode: 10160, Description: "Provided priority group does not exist for this consumer"},
|
||||
JSConsumerInvalidResetErr: {Code: 400, ErrCode: 10204, Description: "invalid reset: {err}"},
|
||||
JSConsumerInvalidSamplingErrF: {Code: 400, ErrCode: 10095, Description: "failed to parse consumer sampling configuration: {err}"},
|
||||
JSConsumerMaxDeliverBackoffErr: {Code: 400, ErrCode: 10116, Description: "max deliver is required to be > length of backoff values"},
|
||||
JSConsumerMaxPendingAckExcessErrF: {Code: 400, ErrCode: 10121, Description: "consumer max ack pending exceeds system limit of {limit}"},
|
||||
@@ -715,11 +790,16 @@ var (
|
||||
JSMessageSchedulesDisabledErr: {Code: 400, ErrCode: 10188, Description: "message schedules is disabled"},
|
||||
JSMessageSchedulesPatternInvalidErr: {Code: 400, ErrCode: 10189, Description: "message schedules pattern is invalid"},
|
||||
JSMessageSchedulesRollupInvalidErr: {Code: 400, ErrCode: 10192, Description: "message schedules invalid rollup"},
|
||||
JSMessageSchedulesSchedulerInvalidErr: {Code: 400, ErrCode: 10212, Description: "message schedules invalid scheduler"},
|
||||
JSMessageSchedulesSourceInvalidErr: {Code: 400, ErrCode: 10203, Description: "message schedules source is invalid"},
|
||||
JSMessageSchedulesTTLInvalidErr: {Code: 400, ErrCode: 10191, Description: "message schedules invalid per-message TTL"},
|
||||
JSMessageSchedulesTargetInvalidErr: {Code: 400, ErrCode: 10190, Description: "message schedules target is invalid"},
|
||||
JSMessageSchedulesTimeZoneInvalidErr: {Code: 400, ErrCode: 10223, Description: "message schedules time zone is invalid"},
|
||||
JSMessageTTLDisabledErr: {Code: 400, ErrCode: 10166, Description: "per-message TTL is disabled"},
|
||||
JSMessageTTLInvalidErr: {Code: 400, ErrCode: 10165, Description: "invalid per-message TTL"},
|
||||
JSMirrorConsumerRequiresAckFCErr: {Code: 400, ErrCode: 10214, Description: "stream mirror consumer requires flow control ack policy"},
|
||||
JSMirrorConsumerSetupFailedErrF: {Code: 500, ErrCode: 10029, Description: "{err}"},
|
||||
JSMirrorDurableConsumerCfgInvalid: {Code: 400, ErrCode: 10213, Description: "stream mirror consumer config is invalid"},
|
||||
JSMirrorInvalidStreamName: {Code: 400, ErrCode: 10142, Description: "mirrored stream name is invalid"},
|
||||
JSMirrorInvalidSubjectFilter: {Code: 400, ErrCode: 10151, Description: "mirror transform source: {err}"},
|
||||
JSMirrorInvalidTransformDestination: {Code: 400, ErrCode: 10154, Description: "mirror transform: {err}"},
|
||||
@@ -727,6 +807,7 @@ var (
|
||||
JSMirrorMultipleFiltersNotAllowed: {Code: 400, ErrCode: 10150, Description: "mirror with multiple subject transforms cannot also have a single subject filter"},
|
||||
JSMirrorOverlappingSubjectFilters: {Code: 400, ErrCode: 10152, Description: "mirror subject filters can not overlap"},
|
||||
JSMirrorWithAtomicPublishErr: {Code: 400, ErrCode: 10198, Description: "stream mirrors can not also use atomic publishing"},
|
||||
JSMirrorWithBatchPublishErr: {Code: 400, ErrCode: 10209, Description: "stream mirrors can not also use batch publishing"},
|
||||
JSMirrorWithCountersErr: {Code: 400, ErrCode: 10173, Description: "stream mirrors can not also calculate counters"},
|
||||
JSMirrorWithFirstSeqErr: {Code: 400, ErrCode: 10143, Description: "stream mirrors can not have first sequence configured"},
|
||||
JSMirrorWithMsgSchedulesErr: {Code: 400, ErrCode: 10186, Description: "stream mirrors can not also schedule messages"},
|
||||
@@ -748,8 +829,11 @@ var (
|
||||
JSRestoreSubscribeFailedErrF: {Code: 500, ErrCode: 10042, Description: "JetStream unable to subscribe to restore snapshot {subject}: {err}"},
|
||||
JSSequenceNotFoundErrF: {Code: 400, ErrCode: 10043, Description: "sequence {seq} not found"},
|
||||
JSSnapshotDeliverSubjectInvalidErr: {Code: 400, ErrCode: 10015, Description: "deliver subject not valid"},
|
||||
JSSourceConsumerRequiresAckFCErr: {Code: 400, ErrCode: 10217, Description: "stream source consumer requires flow control ack policy"},
|
||||
JSSourceConsumerSetupFailedErrF: {Code: 500, ErrCode: 10045, Description: "{err}"},
|
||||
JSSourceDuplicateDetected: {Code: 400, ErrCode: 10140, Description: "duplicate source configuration detected"},
|
||||
JSSourceDurableConsumerCfgInvalid: {Code: 400, ErrCode: 10215, Description: "stream source consumer config is invalid"},
|
||||
JSSourceDurableConsumerDuplicateDetected: {Code: 400, ErrCode: 10216, Description: "duplicate stream source consumer detected"},
|
||||
JSSourceInvalidStreamName: {Code: 400, ErrCode: 10141, Description: "sourced stream name is invalid"},
|
||||
JSSourceInvalidSubjectFilter: {Code: 400, ErrCode: 10145, Description: "source transform source: {err}"},
|
||||
JSSourceInvalidTransformDestination: {Code: 400, ErrCode: 10146, Description: "source transform: {err}"},
|
||||
@@ -923,6 +1007,16 @@ func NewJSAtomicPublishTooLargeBatchError(size interface{}, opts ...ErrorOption)
|
||||
}
|
||||
}
|
||||
|
||||
// NewJSAtomicPublishTooManyInflightError creates a new JSAtomicPublishTooManyInflight error: "atomic publish too many inflight"
|
||||
func NewJSAtomicPublishTooManyInflightError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSAtomicPublishTooManyInflight]
|
||||
}
|
||||
|
||||
// NewJSAtomicPublishUnsupportedHeaderBatchError creates a new JSAtomicPublishUnsupportedHeaderBatchErr error: "atomic publish unsupported header used: {header}"
|
||||
func NewJSAtomicPublishUnsupportedHeaderBatchError(header interface{}, opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -949,6 +1043,56 @@ func NewJSBadRequestError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSBadRequestErr]
|
||||
}
|
||||
|
||||
// NewJSBatchPublishDisabledError creates a new JSBatchPublishDisabledErr error: "batch publish is disabled"
|
||||
func NewJSBatchPublishDisabledError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSBatchPublishDisabledErr]
|
||||
}
|
||||
|
||||
// NewJSBatchPublishInvalidBatchIDError creates a new JSBatchPublishInvalidBatchIDErr error: "batch publish ID is invalid"
|
||||
func NewJSBatchPublishInvalidBatchIDError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSBatchPublishInvalidBatchIDErr]
|
||||
}
|
||||
|
||||
// NewJSBatchPublishInvalidPatternError creates a new JSBatchPublishInvalidPatternErr error: "batch publish pattern is invalid"
|
||||
func NewJSBatchPublishInvalidPatternError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSBatchPublishInvalidPatternErr]
|
||||
}
|
||||
|
||||
// NewJSBatchPublishTooManyInflightError creates a new JSBatchPublishTooManyInflight error: "batch publish too many inflight"
|
||||
func NewJSBatchPublishTooManyInflightError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSBatchPublishTooManyInflight]
|
||||
}
|
||||
|
||||
// NewJSBatchPublishUnknownBatchIDError creates a new JSBatchPublishUnknownBatchIDErr error: "batch publish ID unknown"
|
||||
func NewJSBatchPublishUnknownBatchIDError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSBatchPublishUnknownBatchIDErr]
|
||||
}
|
||||
|
||||
// NewJSClusterIncompleteError creates a new JSClusterIncompleteErr error: "incomplete results"
|
||||
func NewJSClusterIncompleteError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -1075,6 +1219,56 @@ func NewJSClusterUnSupportFeatureError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSClusterUnSupportFeatureErr]
|
||||
}
|
||||
|
||||
// NewJSConsumerAckFCRequiresFCError creates a new JSConsumerAckFCRequiresFCErr error: "flow control ack policy requires flow control"
|
||||
func NewJSConsumerAckFCRequiresFCError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSConsumerAckFCRequiresFCErr]
|
||||
}
|
||||
|
||||
// NewJSConsumerAckFCRequiresMaxAckPendingError creates a new JSConsumerAckFCRequiresMaxAckPendingErr error: "flow control ack policy requires max ack pending"
|
||||
func NewJSConsumerAckFCRequiresMaxAckPendingError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSConsumerAckFCRequiresMaxAckPendingErr]
|
||||
}
|
||||
|
||||
// NewJSConsumerAckFCRequiresNoAckWaitError creates a new JSConsumerAckFCRequiresNoAckWaitErr error: "flow control ack policy requires unset ack wait"
|
||||
func NewJSConsumerAckFCRequiresNoAckWaitError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSConsumerAckFCRequiresNoAckWaitErr]
|
||||
}
|
||||
|
||||
// NewJSConsumerAckFCRequiresNoMaxDeliverError creates a new JSConsumerAckFCRequiresNoMaxDeliverErr error: "flow control ack policy requires unset max deliver"
|
||||
func NewJSConsumerAckFCRequiresNoMaxDeliverError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSConsumerAckFCRequiresNoMaxDeliverErr]
|
||||
}
|
||||
|
||||
// NewJSConsumerAckFCRequiresPushError creates a new JSConsumerAckFCRequiresPushErr error: "flow control ack policy requires a push based consumer"
|
||||
func NewJSConsumerAckFCRequiresPushError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSConsumerAckFCRequiresPushErr]
|
||||
}
|
||||
|
||||
// NewJSConsumerAckPolicyInvalidError creates a new JSConsumerAckPolicyInvalidErr error: "consumer ack policy invalid"
|
||||
func NewJSConsumerAckPolicyInvalidError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -1419,6 +1613,22 @@ func NewJSConsumerInvalidPriorityGroupError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSConsumerInvalidPriorityGroupErr]
|
||||
}
|
||||
|
||||
// NewJSConsumerInvalidResetError creates a new JSConsumerInvalidResetErr error: "invalid reset: {err}"
|
||||
func NewJSConsumerInvalidResetError(err error, opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
e := ApiErrors[JSConsumerInvalidResetErr]
|
||||
args := e.toReplacerArgs([]interface{}{"{err}", err})
|
||||
return &ApiError{
|
||||
Code: e.Code,
|
||||
ErrCode: e.ErrCode,
|
||||
Description: strings.NewReplacer(args...).Replace(e.Description),
|
||||
}
|
||||
}
|
||||
|
||||
// NewJSConsumerInvalidSamplingError creates a new JSConsumerInvalidSamplingErrF error: "failed to parse consumer sampling configuration: {err}"
|
||||
func NewJSConsumerInvalidSamplingError(err error, opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -1967,6 +2177,26 @@ func NewJSMessageSchedulesRollupInvalidError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSMessageSchedulesRollupInvalidErr]
|
||||
}
|
||||
|
||||
// NewJSMessageSchedulesSchedulerInvalidError creates a new JSMessageSchedulesSchedulerInvalidErr error: "message schedules invalid scheduler"
|
||||
func NewJSMessageSchedulesSchedulerInvalidError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSMessageSchedulesSchedulerInvalidErr]
|
||||
}
|
||||
|
||||
// NewJSMessageSchedulesSourceInvalidError creates a new JSMessageSchedulesSourceInvalidErr error: "message schedules source is invalid"
|
||||
func NewJSMessageSchedulesSourceInvalidError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSMessageSchedulesSourceInvalidErr]
|
||||
}
|
||||
|
||||
// NewJSMessageSchedulesTTLInvalidError creates a new JSMessageSchedulesTTLInvalidErr error: "message schedules invalid per-message TTL"
|
||||
func NewJSMessageSchedulesTTLInvalidError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -1987,6 +2217,16 @@ func NewJSMessageSchedulesTargetInvalidError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSMessageSchedulesTargetInvalidErr]
|
||||
}
|
||||
|
||||
// NewJSMessageSchedulesTimeZoneInvalidError creates a new JSMessageSchedulesTimeZoneInvalidErr error: "message schedules time zone is invalid"
|
||||
func NewJSMessageSchedulesTimeZoneInvalidError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSMessageSchedulesTimeZoneInvalidErr]
|
||||
}
|
||||
|
||||
// NewJSMessageTTLDisabledError creates a new JSMessageTTLDisabledErr error: "per-message TTL is disabled"
|
||||
func NewJSMessageTTLDisabledError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -2007,6 +2247,16 @@ func NewJSMessageTTLInvalidError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSMessageTTLInvalidErr]
|
||||
}
|
||||
|
||||
// NewJSMirrorConsumerRequiresAckFCError creates a new JSMirrorConsumerRequiresAckFCErr error: "stream mirror consumer requires flow control ack policy"
|
||||
func NewJSMirrorConsumerRequiresAckFCError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSMirrorConsumerRequiresAckFCErr]
|
||||
}
|
||||
|
||||
// NewJSMirrorConsumerSetupFailedError creates a new JSMirrorConsumerSetupFailedErrF error: "{err}"
|
||||
func NewJSMirrorConsumerSetupFailedError(err error, opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -2023,6 +2273,16 @@ func NewJSMirrorConsumerSetupFailedError(err error, opts ...ErrorOption) *ApiErr
|
||||
}
|
||||
}
|
||||
|
||||
// NewJSMirrorDurableConsumerCfgInvalidError creates a new JSMirrorDurableConsumerCfgInvalid error: "stream mirror consumer config is invalid"
|
||||
func NewJSMirrorDurableConsumerCfgInvalidError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSMirrorDurableConsumerCfgInvalid]
|
||||
}
|
||||
|
||||
// NewJSMirrorInvalidStreamNameError creates a new JSMirrorInvalidStreamName error: "mirrored stream name is invalid"
|
||||
func NewJSMirrorInvalidStreamNameError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -2105,6 +2365,16 @@ func NewJSMirrorWithAtomicPublishError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSMirrorWithAtomicPublishErr]
|
||||
}
|
||||
|
||||
// NewJSMirrorWithBatchPublishError creates a new JSMirrorWithBatchPublishErr error: "stream mirrors can not also use batch publishing"
|
||||
func NewJSMirrorWithBatchPublishError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSMirrorWithBatchPublishErr]
|
||||
}
|
||||
|
||||
// NewJSMirrorWithCountersError creates a new JSMirrorWithCountersErr error: "stream mirrors can not also calculate counters"
|
||||
func NewJSMirrorWithCountersError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -2339,6 +2609,16 @@ func NewJSSnapshotDeliverSubjectInvalidError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSSnapshotDeliverSubjectInvalidErr]
|
||||
}
|
||||
|
||||
// NewJSSourceConsumerRequiresAckFCError creates a new JSSourceConsumerRequiresAckFCErr error: "stream source consumer requires flow control ack policy"
|
||||
func NewJSSourceConsumerRequiresAckFCError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSSourceConsumerRequiresAckFCErr]
|
||||
}
|
||||
|
||||
// NewJSSourceConsumerSetupFailedError creates a new JSSourceConsumerSetupFailedErrF error: "{err}"
|
||||
func NewJSSourceConsumerSetupFailedError(err error, opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
@@ -2365,6 +2645,26 @@ func NewJSSourceDuplicateDetectedError(opts ...ErrorOption) *ApiError {
|
||||
return ApiErrors[JSSourceDuplicateDetected]
|
||||
}
|
||||
|
||||
// NewJSSourceDurableConsumerCfgInvalidError creates a new JSSourceDurableConsumerCfgInvalid error: "stream source consumer config is invalid"
|
||||
func NewJSSourceDurableConsumerCfgInvalidError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSSourceDurableConsumerCfgInvalid]
|
||||
}
|
||||
|
||||
// NewJSSourceDurableConsumerDuplicateDetectedError creates a new JSSourceDurableConsumerDuplicateDetected error: "duplicate stream source consumer detected"
|
||||
func NewJSSourceDurableConsumerDuplicateDetectedError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
if ae, ok := eopts.err.(*ApiError); ok {
|
||||
return ae
|
||||
}
|
||||
|
||||
return ApiErrors[JSSourceDurableConsumerDuplicateDetected]
|
||||
}
|
||||
|
||||
// NewJSSourceInvalidStreamNameError creates a new JSSourceInvalidStreamName error: "sourced stream name is invalid"
|
||||
func NewJSSourceInvalidStreamNameError(opts ...ErrorOption) *ApiError {
|
||||
eopts := parseOpts(opts)
|
||||
|
||||
+7
-7
@@ -71,10 +71,9 @@ const (
|
||||
// JSStreamActionAdvisory indicates that a stream was created, edited or deleted
|
||||
type JSStreamActionAdvisory struct {
|
||||
TypedEvent
|
||||
Stream string `json:"stream"`
|
||||
Action ActionAdvisoryType `json:"action"`
|
||||
Template string `json:"template,omitempty"` // Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Stream string `json:"stream"`
|
||||
Action ActionAdvisoryType `json:"action"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
}
|
||||
|
||||
const JSStreamActionAdvisoryType = "io.nats.jetstream.advisory.v1.stream_action"
|
||||
@@ -269,9 +268,10 @@ type JSStreamBatchAbandonedAdvisory struct {
|
||||
type BatchAbandonReason string
|
||||
|
||||
var (
|
||||
BatchTimeout BatchAbandonReason = "timeout"
|
||||
BatchLarge BatchAbandonReason = "large"
|
||||
BatchIncomplete BatchAbandonReason = "incomplete"
|
||||
BatchTimeout BatchAbandonReason = "timeout"
|
||||
BatchLarge BatchAbandonReason = "large"
|
||||
BatchIncomplete BatchAbandonReason = "incomplete"
|
||||
BatchRequirementsNotMet BatchAbandonReason = "unsupported"
|
||||
)
|
||||
|
||||
// JSConsumerLeaderElectedAdvisoryType is sent when the system elects a leader for a consumer.
|
||||
|
||||
+11
-1
@@ -17,7 +17,7 @@ import "strconv"
|
||||
|
||||
const (
|
||||
// JSApiLevel is the maximum supported JetStream API level for this server.
|
||||
JSApiLevel int = 3
|
||||
JSApiLevel int = 4
|
||||
|
||||
JSRequiredLevelMetadataKey = "_nats.req.level"
|
||||
JSServerVersionMetadataKey = "_nats.ver"
|
||||
@@ -82,6 +82,11 @@ func setStaticStreamMetadata(cfg *StreamConfig) {
|
||||
requires(2)
|
||||
}
|
||||
|
||||
// Fast batch publishing was added in v2.14 and requires API level 4.
|
||||
if cfg.AllowBatchPublish {
|
||||
requires(4)
|
||||
}
|
||||
|
||||
cfg.Metadata[JSRequiredLevelMetadataKey] = strconv.Itoa(requiredApiLevel)
|
||||
}
|
||||
|
||||
@@ -158,6 +163,11 @@ func setStaticConsumerMetadata(cfg *ConsumerConfig) {
|
||||
requires(1)
|
||||
}
|
||||
|
||||
// Added in 2.14
|
||||
if cfg.AckPolicy == AckFlowControl {
|
||||
requires(4)
|
||||
}
|
||||
|
||||
cfg.Metadata[JSRequiredLevelMetadataKey] = strconv.Itoa(requiredApiLevel)
|
||||
}
|
||||
|
||||
|
||||
+44
-14
@@ -202,12 +202,15 @@ func validateSrc(claims *jwt.UserClaims, host string) bool {
|
||||
}
|
||||
|
||||
func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) {
|
||||
return validateTimesAt(claims, time.Now())
|
||||
}
|
||||
|
||||
func validateTimesAt(claims *jwt.UserClaims, now time.Time) (bool, time.Duration) {
|
||||
if claims == nil {
|
||||
return false, time.Duration(0)
|
||||
} else if len(claims.Times) == 0 {
|
||||
return true, time.Duration(0)
|
||||
}
|
||||
now := time.Now()
|
||||
loc := time.Local
|
||||
if claims.Locale != "" {
|
||||
var err error
|
||||
@@ -216,10 +219,11 @@ func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) {
|
||||
}
|
||||
now = now.In(loc)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
var validFor time.Duration
|
||||
|
||||
for _, timeRange := range claims.Times {
|
||||
y, m, d := now.Date()
|
||||
m = m - 1
|
||||
d = d - 1
|
||||
start, err := time.ParseInLocation("15:04:05", timeRange.Start, loc)
|
||||
if err != nil {
|
||||
return false, time.Duration(0) // parsing not expected to fail at this point
|
||||
@@ -228,17 +232,43 @@ func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) {
|
||||
if err != nil {
|
||||
return false, time.Duration(0) // parsing not expected to fail at this point
|
||||
}
|
||||
if start.After(end) {
|
||||
start = start.AddDate(y, int(m), d)
|
||||
d++ // the intent is to be the next day
|
||||
} else {
|
||||
start = start.AddDate(y, int(m), d)
|
||||
|
||||
y, m, d := now.Date()
|
||||
start = time.Date(y, m, d, start.Hour(), start.Minute(), start.Second(), 0, loc)
|
||||
end = time.Date(y, m, d, end.Hour(), end.Minute(), end.Second(), 0, loc)
|
||||
|
||||
inRange, expires := validateTimeRangeAt(start, end, now)
|
||||
if inRange && (!ok || expires > validFor) {
|
||||
ok = true
|
||||
validFor = expires
|
||||
}
|
||||
if start.Before(now) {
|
||||
end = end.AddDate(y, int(m), d)
|
||||
if end.After(now) {
|
||||
return true, end.Sub(now)
|
||||
}
|
||||
}
|
||||
return ok, validFor
|
||||
}
|
||||
|
||||
// Returns true if now is within `start` and `end`, and
|
||||
// how much time is left until `end`.
|
||||
// False if `now` is not within range.
|
||||
func validateTimeRangeAt(start, end, now time.Time) (bool, time.Duration) {
|
||||
// Now falls within range.
|
||||
// For example 11:00-22:00 at 13:00
|
||||
if start.Before(now) && end.After(now) {
|
||||
return true, end.Sub(now)
|
||||
}
|
||||
|
||||
// Range crosses midnight.
|
||||
if start.After(end) {
|
||||
// Now is after midnight.
|
||||
// For example 22:00-06:00 at 05:00.
|
||||
if end.After(now) {
|
||||
return true, end.Sub(now)
|
||||
}
|
||||
|
||||
// Now is before midnight.
|
||||
// For example 22:00-06:00 at 23:30.
|
||||
end = end.AddDate(0, 0, 1)
|
||||
if start.Before(now) && end.After(now) {
|
||||
return true, end.Sub(now)
|
||||
}
|
||||
}
|
||||
return false, time.Duration(0)
|
||||
|
||||
+317
-91
@@ -1,4 +1,4 @@
|
||||
// Copyright 2019-2025 The NATS Authors
|
||||
// Copyright 2019-2026 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
@@ -27,7 +27,6 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -115,21 +114,24 @@ type leafNodeCfg struct {
|
||||
perms *Permissions
|
||||
connDelay time.Duration // Delay before a connect, could be used while detecting loop condition, etc..
|
||||
jsMigrateTimer *time.Timer
|
||||
quitCh chan struct{}
|
||||
removed bool
|
||||
connInProgress bool
|
||||
}
|
||||
|
||||
// Check to see if this is a solicited leafnode. We do special processing for solicited.
|
||||
func (c *client) isSolicitedLeafNode() bool {
|
||||
return c.kind == LEAF && c.leaf.remote != nil
|
||||
return c.kind == LEAF && c.leaf != nil && c.leaf.remote != nil
|
||||
}
|
||||
|
||||
// Returns true if this is a solicited leafnode and is not configured to be treated as a hub or a receiving
|
||||
// connection leafnode where the otherside has declared itself to be the hub.
|
||||
func (c *client) isSpokeLeafNode() bool {
|
||||
return c.kind == LEAF && c.leaf.isSpoke
|
||||
return c.kind == LEAF && c.leaf != nil && c.leaf.isSpoke
|
||||
}
|
||||
|
||||
func (c *client) isHubLeafNode() bool {
|
||||
return c.kind == LEAF && !c.leaf.isSpoke
|
||||
return c.kind == LEAF && c.leaf != nil && !c.leaf.isSpoke
|
||||
}
|
||||
|
||||
func (c *client) isIsolatedLeafNode() bool {
|
||||
@@ -137,7 +139,7 @@ func (c *client) isIsolatedLeafNode() bool {
|
||||
// group name here, which the hub and/or leaf could provide, so that we
|
||||
// can isolate away certain LNs but not others on an opt-in basis. For
|
||||
// now we will just isolate all LN interest until then.
|
||||
return c.kind == LEAF && c.leaf.isolated
|
||||
return c.kind == LEAF && c.leaf != nil && c.leaf.isolated
|
||||
}
|
||||
|
||||
// This will spin up go routines to solicit the remote leaf node connections.
|
||||
@@ -152,7 +154,10 @@ func (s *Server) solicitLeafNodeRemotes(remotes []*RemoteLeafOpts) {
|
||||
remote := newLeafNodeCfg(r)
|
||||
creds := remote.Credentials
|
||||
accName := remote.LocalAccount
|
||||
s.leafRemoteCfgs = append(s.leafRemoteCfgs, remote)
|
||||
if s.leafRemoteCfgs == nil {
|
||||
s.leafRemoteCfgs = make(map[*leafNodeCfg]struct{})
|
||||
}
|
||||
s.leafRemoteCfgs[remote] = struct{}{}
|
||||
// Print notice if
|
||||
if isSysAccRemote {
|
||||
if len(remote.DenyExports) > 0 {
|
||||
@@ -192,34 +197,30 @@ func (s *Server) solicitLeafNodeRemotes(remotes []*RemoteLeafOpts) {
|
||||
// configuration required for configuration reload.
|
||||
remote := addRemote(r, r.LocalAccount == sysAccName)
|
||||
if !r.Disabled {
|
||||
s.startGoRoutine(func() { s.connectToRemoteLeafNode(remote, true) })
|
||||
s.connectToRemoteLeafNodeAsynchronously(remote, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) remoteLeafNodeStillValid(remote *leafNodeCfg) bool {
|
||||
if remote.Disabled {
|
||||
return false
|
||||
}
|
||||
for _, ri := range s.getOpts().LeafNode.Remotes {
|
||||
// FIXME(dlc) - What about auth changes?
|
||||
if reflect.DeepEqual(ri.URLs, remote.URLs) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Ensure that leafnode is properly configured.
|
||||
func validateLeafNode(o *Options) error {
|
||||
if err := validateLeafNodeAuthOptions(o); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Users can bind to any local account, if its empty we will assume the $G account.
|
||||
for _, r := range o.LeafNode.Remotes {
|
||||
if r.LocalAccount == _EMPTY_ {
|
||||
r.LocalAccount = globalAccountName
|
||||
if len(o.LeafNode.Remotes) > 0 {
|
||||
names := make(map[string]struct{})
|
||||
// Check for duplicate remotes, also, users can bind to any local account,
|
||||
// if its empty we will assume the $G account.
|
||||
for _, r := range o.LeafNode.Remotes {
|
||||
if r.LocalAccount == _EMPTY_ {
|
||||
r.LocalAccount = globalAccountName
|
||||
}
|
||||
rn := r.name()
|
||||
if _, dup := names[rn]; dup {
|
||||
return fmt.Errorf("duplicate remote %s", r.safeName())
|
||||
}
|
||||
names[rn] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -428,42 +429,25 @@ func validateLeafNodeProxyOptions(remote *RemoteLeafOpts) ([]string, error) {
|
||||
return warnings, nil
|
||||
}
|
||||
|
||||
// Update remote LeafNode TLS configurations after a config reload.
|
||||
func (s *Server) updateRemoteLeafNodesTLSConfig(opts *Options) {
|
||||
max := len(opts.LeafNode.Remotes)
|
||||
if max == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Changes in the list of remote leaf nodes is not supported.
|
||||
// However, make sure that we don't go over the arrays.
|
||||
if len(s.leafRemoteCfgs) < max {
|
||||
max = len(s.leafRemoteCfgs)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
ro := opts.LeafNode.Remotes[i]
|
||||
cfg := s.leafRemoteCfgs[i]
|
||||
if ro.TLSConfig != nil {
|
||||
cfg.Lock()
|
||||
cfg.TLSConfig = ro.TLSConfig.Clone()
|
||||
cfg.TLSHandshakeFirst = ro.TLSHandshakeFirst
|
||||
cfg.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the configured reconnect interval before attempting to connect
|
||||
// again to the remote leafnode.
|
||||
func (s *Server) reConnectToRemoteLeafNode(remote *leafNodeCfg) {
|
||||
clearInProgress := true
|
||||
defer func() {
|
||||
s.grWG.Done()
|
||||
if clearInProgress {
|
||||
remote.setConnectInProgress(false)
|
||||
}
|
||||
}()
|
||||
delay := s.getOpts().LeafNode.ReconnectInterval
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-remote.quitCh:
|
||||
return
|
||||
case <-s.quitCh:
|
||||
s.grWG.Done()
|
||||
return
|
||||
}
|
||||
s.connectToRemoteLeafNode(remote, false)
|
||||
clearInProgress = !connectToRemoteLeafNode(s, remote, false)
|
||||
}
|
||||
|
||||
// Creates a leafNodeCfg object that wraps the RemoteLeafOpts.
|
||||
@@ -471,6 +455,7 @@ func newLeafNodeCfg(remote *RemoteLeafOpts) *leafNodeCfg {
|
||||
cfg := &leafNodeCfg{
|
||||
RemoteLeafOpts: remote,
|
||||
urls: make([]*url.URL, 0, len(remote.URLs)),
|
||||
quitCh: make(chan struct{}, 1),
|
||||
}
|
||||
if len(remote.DenyExports) > 0 || len(remote.DenyImports) > 0 {
|
||||
perms := &Permissions{}
|
||||
@@ -506,6 +491,53 @@ func newLeafNodeCfg(remote *RemoteLeafOpts) *leafNodeCfg {
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Notifies the quit channel without blocking.
|
||||
// No lock is needed to invoke this function.
|
||||
func (cfg *leafNodeCfg) notifyQuitChannel() {
|
||||
select {
|
||||
case cfg.quitCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Sets the connect-in-progress status for this remote leaf configuration.
|
||||
func (cfg *leafNodeCfg) setConnectInProgress(inProgress bool) {
|
||||
cfg.Lock()
|
||||
defer cfg.Unlock()
|
||||
// In both cases we want to drain the "quit" channel.
|
||||
select {
|
||||
case <-cfg.quitCh:
|
||||
default:
|
||||
}
|
||||
cfg.connInProgress = inProgress
|
||||
}
|
||||
|
||||
// Returns `true` if this remote is in the middle of a connect, `false` otherwise.
|
||||
func (cfg *leafNodeCfg) isConnectInProgress() bool {
|
||||
cfg.RLock()
|
||||
defer cfg.RUnlock()
|
||||
return cfg.connInProgress
|
||||
}
|
||||
|
||||
// Mark that this remote is being removed from the configuration.
|
||||
func (cfg *leafNodeCfg) markAsRemoved() {
|
||||
cfg.Lock()
|
||||
defer cfg.Unlock()
|
||||
// This function should be invoked only once, but protect.
|
||||
if cfg.removed {
|
||||
return
|
||||
}
|
||||
cfg.removed = true
|
||||
cfg.notifyQuitChannel()
|
||||
}
|
||||
|
||||
// Returns false if it has been disabled or removed.
|
||||
func (cfg *leafNodeCfg) stillValid() bool {
|
||||
cfg.RLock()
|
||||
defer cfg.RUnlock()
|
||||
return !cfg.Disabled && !cfg.removed
|
||||
}
|
||||
|
||||
// Will pick an URL from the list of available URLs.
|
||||
func (cfg *leafNodeCfg) pickNextURL() *url.URL {
|
||||
cfg.Lock()
|
||||
@@ -622,12 +654,26 @@ func establishHTTPProxyTunnel(proxyURL, targetHost string, timeout time.Duration
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool) {
|
||||
defer s.grWG.Done()
|
||||
// Connect to a remote leaf node asynchronously (that is, this function will do
|
||||
// the connect in a go routine).
|
||||
func (s *Server) connectToRemoteLeafNodeAsynchronously(remote *leafNodeCfg, firstConnect bool) {
|
||||
remote.setConnectInProgress(true)
|
||||
s.startGoRoutine(func() {
|
||||
defer s.grWG.Done()
|
||||
if !connectToRemoteLeafNode(s, remote, firstConnect) {
|
||||
remote.setConnectInProgress(false)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Connect to a remote leaf node. Should only be invoked from
|
||||
// `s.connectToRemoteLeafNodeAsynchronously()` or `s.reConnectToRemoteLeafNode()`.
|
||||
// Returns `true` if this function invoked `s.createLeafNode()`, false otherwise.
|
||||
func connectToRemoteLeafNode(s *Server, remote *leafNodeCfg, firstConnect bool) bool {
|
||||
|
||||
if remote == nil || len(remote.URLs) == 0 {
|
||||
s.Debugf("Empty remote leafnode definition, nothing to connect")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
opts := s.getOpts()
|
||||
@@ -651,8 +697,10 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
|
||||
if connDelay := remote.getConnectDelay(); connDelay > 0 {
|
||||
select {
|
||||
case <-time.After(connDelay):
|
||||
case <-remote.quitCh:
|
||||
return false
|
||||
case <-s.quitCh:
|
||||
return
|
||||
return false
|
||||
}
|
||||
remote.setConnectDelay(0)
|
||||
}
|
||||
@@ -676,7 +724,14 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
|
||||
|
||||
attempts := 0
|
||||
|
||||
for s.isRunning() && s.remoteLeafNodeStillValid(remote) {
|
||||
// In case the migrate timer was created but not canceled, do it when
|
||||
// this function exits. Note that the timer would not be created if
|
||||
// `jetstreamMigrateDelay == 0`.
|
||||
if jetstreamMigrateDelay > 0 {
|
||||
defer remote.cancelMigrateTimer()
|
||||
}
|
||||
|
||||
for s.isRunning() && remote.stillValid() {
|
||||
rURL := remote.pickNextURL()
|
||||
url, err := s.getRandomIP(resolver, rURL.Host, nil)
|
||||
if err == nil {
|
||||
@@ -729,8 +784,9 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
|
||||
remote.Unlock()
|
||||
select {
|
||||
case <-s.quitCh:
|
||||
remote.cancelMigrateTimer()
|
||||
return
|
||||
return false
|
||||
case <-remote.quitCh:
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
// Check if we should migrate any JetStream assets immediately while this remote is down.
|
||||
// This will be used if JetStreamClusterMigrateDelay was not set
|
||||
@@ -741,9 +797,11 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
|
||||
}
|
||||
}
|
||||
remote.cancelMigrateTimer()
|
||||
if !s.remoteLeafNodeStillValid(remote) {
|
||||
// We can check here, but really we will have to check again when the server
|
||||
// is about to add to the `s.leafs` map later in the process.
|
||||
if !remote.stillValid() {
|
||||
conn.Close()
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
// We have a connection here to a remote server.
|
||||
@@ -753,8 +811,10 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
|
||||
// Clear any observer states if we had them.
|
||||
s.clearObserverState(remote)
|
||||
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (cfg *leafNodeCfg) cancelMigrateTimer() {
|
||||
@@ -854,6 +914,8 @@ func (s *Server) isLeafConnectDisabled() bool {
|
||||
// their remote connections did not have a tls{} block).
|
||||
// We now save the host name regardless in case the remote returns an INFO indicating
|
||||
// that TLS is required.
|
||||
//
|
||||
// Lock held on entry.
|
||||
func (cfg *leafNodeCfg) saveTLSHostname(u *url.URL) {
|
||||
if cfg.tlsName == _EMPTY_ && net.ParseIP(u.Hostname()) == nil {
|
||||
cfg.tlsName = u.Hostname()
|
||||
@@ -862,6 +924,8 @@ func (cfg *leafNodeCfg) saveTLSHostname(u *url.URL) {
|
||||
|
||||
// Save off the username/password for when we connect using a bare URL
|
||||
// that we get from the INFO protocol.
|
||||
//
|
||||
// Lock held on entry.
|
||||
func (cfg *leafNodeCfg) saveUserPassword(u *url.URL) {
|
||||
if cfg.username == _EMPTY_ && u.User != nil {
|
||||
cfg.username = u.User.Username()
|
||||
@@ -1459,6 +1523,15 @@ func (c *client) processLeafnodeInfo(info *Info) {
|
||||
|
||||
// Check for compression, unless already done.
|
||||
if firstINFO && !c.flags.isSet(compressionNegotiated) {
|
||||
// A solicited leafnode connection must first receive a leafnode INFO.
|
||||
// Classify wrong-port connections before any leaf-specific negotiation.
|
||||
if didSolicit && (info.CID == 0 || info.LeafNodeURLs == nil) {
|
||||
c.mu.Unlock()
|
||||
c.Errorf(ErrConnectedToWrongPort.Error())
|
||||
c.closeConnection(WrongPort)
|
||||
return
|
||||
}
|
||||
|
||||
// Prevent from getting back here.
|
||||
c.flags.set(compressionNegotiated)
|
||||
|
||||
@@ -1536,15 +1609,6 @@ func (c *client) processLeafnodeInfo(info *Info) {
|
||||
// ** Not if "no advertise" is enabled.
|
||||
// *** Not if leafnode's "no advertise" is enabled.
|
||||
//
|
||||
// As seen from above, a solicited LeafNode connection should receive
|
||||
// from the remote server an INFO with CID and LeafNodeURLs. Anything
|
||||
// else should be considered an attempt to connect to a wrong port.
|
||||
if didSolicit && (info.CID == 0 || info.LeafNodeURLs == nil) {
|
||||
c.mu.Unlock()
|
||||
c.Errorf(ErrConnectedToWrongPort.Error())
|
||||
c.closeConnection(WrongPort)
|
||||
return
|
||||
}
|
||||
// Reject a cluster that contains spaces.
|
||||
if info.Cluster != _EMPTY_ && strings.Contains(info.Cluster, " ") {
|
||||
c.mu.Unlock()
|
||||
@@ -1552,8 +1616,12 @@ func (c *client) processLeafnodeInfo(info *Info) {
|
||||
c.closeConnection(ProtocolViolation)
|
||||
return
|
||||
}
|
||||
// Capture a nonce here.
|
||||
c.nonce = []byte(info.Nonce)
|
||||
// For solicited outbound leaf connections, capture the remote's nonce.
|
||||
// For inbound leaf connections, keep using the server-issued nonce that
|
||||
// was sent in our initial INFO and must be signed in CONNECT.
|
||||
if didSolicit {
|
||||
c.nonce = []byte(info.Nonce)
|
||||
}
|
||||
if info.TLSRequired && didSolicit {
|
||||
remote.TLS = true
|
||||
}
|
||||
@@ -1578,15 +1646,17 @@ func (c *client) processLeafnodeInfo(info *Info) {
|
||||
}
|
||||
|
||||
// For both initial INFO and async INFO protocols, Possibly
|
||||
// update our list of remote leafnode URLs we can connect to.
|
||||
if didSolicit && (len(info.LeafNodeURLs) > 0 || len(info.WSConnectURLs) > 0) {
|
||||
// update our list of remote leafnode URLs we can connect to,
|
||||
// unless we are instructed not to.
|
||||
if didSolicit && !remote.IgnoreDiscoveredServers &&
|
||||
(len(info.LeafNodeURLs) > 0 || len(info.WSConnectURLs) > 0) {
|
||||
// Consider the incoming array as the most up-to-date
|
||||
// representation of the remote cluster's list of URLs.
|
||||
c.updateLeafNodeURLs(info)
|
||||
}
|
||||
|
||||
// Check to see if we have permissions updates here.
|
||||
if info.Import != nil || info.Export != nil {
|
||||
// Only solicited leafnode connections trust permission updates from INFO.
|
||||
if didSolicit && (info.Import != nil || info.Export != nil) {
|
||||
perms := &Permissions{
|
||||
Publish: info.Export,
|
||||
Subscribe: info.Import,
|
||||
@@ -1623,6 +1693,12 @@ func (c *client) processLeafnodeInfo(info *Info) {
|
||||
|
||||
// Check if we have the remote account information and if so make sure it's stored.
|
||||
if info.RemoteAccount != _EMPTY_ {
|
||||
if c.acc == nil {
|
||||
c.mu.Unlock()
|
||||
c.sendErr("Authorization Violation")
|
||||
c.closeConnection(ProtocolViolation)
|
||||
return
|
||||
}
|
||||
s.leafRemoteAccounts.Store(c.acc.Name, info.RemoteAccount)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
@@ -1807,7 +1883,7 @@ func (s *Server) setLeafNodeInfoHostPortAndIP() error {
|
||||
// (this solves the stale connection situation). An error is returned to help the
|
||||
// remote detect the misconfiguration when the duplicate is the result of that
|
||||
// misconfiguration.
|
||||
func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, checkForDup bool) {
|
||||
func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, checkForDup bool) bool {
|
||||
var accName string
|
||||
c.mu.Lock()
|
||||
cid := c.cid
|
||||
@@ -1819,7 +1895,8 @@ func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, c
|
||||
mySrvName := c.leaf.remoteServer
|
||||
remoteAccName := c.leaf.remoteAccName
|
||||
myClustName := c.leaf.remoteCluster
|
||||
solicited := c.leaf.remote != nil
|
||||
remote := c.leaf.remote
|
||||
solicited := remote != nil
|
||||
c.mu.Unlock()
|
||||
|
||||
var old *client
|
||||
@@ -1843,6 +1920,23 @@ func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, c
|
||||
}
|
||||
}
|
||||
}
|
||||
// Now that we are under the server lock and before adding it to the map,
|
||||
// for a solicited leaf, we need to make sure that it has not been removed
|
||||
// from the config or disabled.
|
||||
if solicited {
|
||||
// If no longer valid, do not add to the server map. The connection
|
||||
// should have been marked so that it can't reconnect. When the caller
|
||||
// calls closeConnection(), cleanup (including clearing the connect-
|
||||
// in-progress flag) will occur at the appropriate time.
|
||||
if !remote.stillValid() {
|
||||
// Prevent reconnect in case it was not yet done.
|
||||
c.setNoReconnect()
|
||||
s.mu.Unlock()
|
||||
s.removeFromTempClients(cid)
|
||||
return false
|
||||
}
|
||||
remote.setConnectInProgress(false)
|
||||
}
|
||||
// Store new connection in the map
|
||||
s.leafs[cid] = c
|
||||
s.mu.Unlock()
|
||||
@@ -1891,7 +1985,7 @@ func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, c
|
||||
} else if domain, ok := opts.JsAccDefaultDomain[accName]; ok && domain == _EMPTY_ {
|
||||
// for backwards compatibility with old setups that do not have a domain name set
|
||||
c.Debugf("Skipping deny %q for account %q due to default domain", jsAllAPI, accName)
|
||||
return
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1969,9 +2063,11 @@ func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, c
|
||||
c.Debugf("Adding deny %q for outgoing messages to account %q", src, accName)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) removeLeafNodeConnection(c *client) {
|
||||
s.mu.Lock()
|
||||
c.mu.Lock()
|
||||
cid := c.cid
|
||||
if c.leaf != nil {
|
||||
@@ -1984,10 +2080,18 @@ func (s *Server) removeLeafNodeConnection(c *client) {
|
||||
// We need to set this to nil for GC to release the connection
|
||||
c.leaf.gwSub = nil
|
||||
}
|
||||
if remote := c.leaf.remote; remote != nil {
|
||||
// If "noReconnect" is true, then we won't attempt to reconnect, so
|
||||
// we will clear the "connect-in-progress" flag. However, if we can
|
||||
// reconnect, then we should set "connect-in-progress" to true while
|
||||
// we are under the server/client lock. The go routine that performs
|
||||
// the reconnect will be started later and there would be a gap with
|
||||
// the wrong flag value otherwise.
|
||||
remote.setConnectInProgress(!c.flags.isSet(noReconnect))
|
||||
}
|
||||
}
|
||||
proxyKey := c.proxyKey
|
||||
c.mu.Unlock()
|
||||
s.mu.Lock()
|
||||
delete(s.leafs, cid)
|
||||
if proxyKey != _EMPTY_ {
|
||||
s.removeProxiedConn(proxyKey, cid)
|
||||
@@ -2154,6 +2258,13 @@ func (c *client) processLeafNodeConnect(s *Server, arg []byte, lang string) erro
|
||||
acc := c.acc
|
||||
c.mu.Unlock()
|
||||
|
||||
// If the account is not set (e.g. connection was closed due to auth
|
||||
// timeout while still being processed), bail out to avoid a panic.
|
||||
if acc == nil {
|
||||
c.closeConnection(MissingAccount)
|
||||
return ErrMissingAccount
|
||||
}
|
||||
|
||||
// Register the cluster, even if empty, as long as we are acting as a hub.
|
||||
if !proto.Hub {
|
||||
acc.registerLeafNodeCluster(proto.Cluster)
|
||||
@@ -2999,6 +3110,11 @@ func (c *client) processLeafHeaderMsgArgs(arg []byte) error {
|
||||
if c.pa.hdr > c.pa.size {
|
||||
return fmt.Errorf("processLeafHeaderMsgArgs Header Size larger then TotalSize: '%s'", arg)
|
||||
}
|
||||
maxPayload := atomic.LoadInt32(&c.mpay)
|
||||
if maxPayload != jwt.NoLimit && int64(c.pa.size) > int64(maxPayload) {
|
||||
c.maxPayloadViolation(c.pa.size, maxPayload)
|
||||
return ErrMaxPayload
|
||||
}
|
||||
|
||||
// Common ones processed after check for arg length
|
||||
c.pa.subject = args[0]
|
||||
@@ -3068,6 +3184,11 @@ func (c *client) processLeafMsgArgs(arg []byte) error {
|
||||
if c.pa.size < 0 {
|
||||
return fmt.Errorf("processLeafMsgArgs Bad or Missing Size: '%s'", args)
|
||||
}
|
||||
maxPayload := atomic.LoadInt32(&c.mpay)
|
||||
if maxPayload != jwt.NoLimit && int64(c.pa.size) > int64(maxPayload) {
|
||||
c.maxPayloadViolation(c.pa.size, maxPayload)
|
||||
return ErrMaxPayload
|
||||
}
|
||||
|
||||
// Common ones processed after check for arg length
|
||||
c.pa.subject = args[0]
|
||||
@@ -3089,6 +3210,12 @@ func (c *client) processInboundLeafMsg(msg []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check that leaf messages respect the subject permissions.
|
||||
if c.perms != nil && !c.leafMsgAllowed() {
|
||||
c.leafPubPermViolation(c.pa.subject)
|
||||
return
|
||||
}
|
||||
|
||||
// Match the subscriptions. We will use our own L1 map if
|
||||
// it's still valid, avoiding contention on the shared sublist.
|
||||
var r *SublistResult
|
||||
@@ -3150,12 +3277,102 @@ func (c *client) processInboundLeafMsg(msg []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks whether the inbound leaf message is allowed by the
|
||||
// connection's permissions. On the hub side this enforces what
|
||||
// the remote leaf may publish. On the spoke side this enforces
|
||||
// import restrictions such as deny_imports.
|
||||
func (c *client) leafMsgAllowed() bool {
|
||||
wireSubject := c.pa.subject
|
||||
if len(c.pa.mapped) > 0 {
|
||||
// Mappings rewrite c.pa.subject to the internal
|
||||
// destination. For leaf ACLs, need to check
|
||||
// the original wire subject from the remote side.
|
||||
wireSubject = c.pa.mapped
|
||||
}
|
||||
// Strip any gateway routing prefix for the permission check.
|
||||
subjectToCheck, isGW := getGWRoutedSubjectOrSelf(wireSubject)
|
||||
|
||||
// Service-import replies (_R_), JS ack subjects ($JS.ACK.)
|
||||
// are internal routing subjects forwarded via LS+ without
|
||||
// permission checks.
|
||||
if isServiceReply(subjectToCheck) || isJSAckSubject(subjectToCheck) {
|
||||
return true
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.isSpokeLeafNode() {
|
||||
// Gateway routed replies are forwarded without
|
||||
// permission checks.
|
||||
if isGW || c.leafReceiveAllowed(subjectToCheck) {
|
||||
return true
|
||||
}
|
||||
} else if c.leafSendAllowed(subjectToCheck) {
|
||||
return true
|
||||
}
|
||||
// Check tracked reply permissions (allow_responses).
|
||||
// Use the pre-strip subject since deliverMsg tracks
|
||||
// replies under the original form, which includes
|
||||
// the GW routing prefix for routed requests.
|
||||
return c.responseAllowed(bytesToString(wireSubject))
|
||||
}
|
||||
|
||||
// Returns true if the leaf side ACLs allow importing this subject,
|
||||
// based on the permissions received over INFO and any local deny_imports.
|
||||
// Lock must be held.
|
||||
func (c *client) leafReceiveAllowed(subject []byte) bool {
|
||||
return c.canSubscribe(bytesToString(subject))
|
||||
}
|
||||
|
||||
// Returns true if the hub side ACLs allow the remote leaf to send
|
||||
// this subject.
|
||||
// Lock must be held.
|
||||
func (c *client) leafSendAllowed(bsubject []byte) bool {
|
||||
// Use the original export ACL captured for this accepted leaf.
|
||||
// The live perms also contain additional JetStream denies used by
|
||||
// the normal forwarding path, and applying them here would reject
|
||||
// legitimate inbound JS API requests.
|
||||
subject := bytesToString(bsubject)
|
||||
perms := c.opts.Export
|
||||
if perms == nil || (perms.Allow == nil && perms.Deny == nil) {
|
||||
return true
|
||||
}
|
||||
|
||||
allowed := true
|
||||
if perms.Allow != nil && !strings.HasPrefix(subject, mqttPrefix) {
|
||||
allowed = false
|
||||
for _, allowSubj := range perms.Allow {
|
||||
if matchLiteral(subject, allowSubj) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if allowed && len(perms.Deny) > 0 {
|
||||
for _, denySubj := range perms.Deny {
|
||||
if matchLiteral(subject, denySubj) {
|
||||
allowed = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
|
||||
// Handles a subscription permission violation.
|
||||
// See leafPermViolation() for details.
|
||||
func (c *client) leafSubPermViolation(subj []byte) {
|
||||
c.leafPermViolation(false, subj)
|
||||
}
|
||||
|
||||
// Handles a publish permission violation.
|
||||
// See leafPermViolation() for details.
|
||||
func (c *client) leafPubPermViolation(subj []byte) {
|
||||
c.leafPermViolation(true, subj)
|
||||
}
|
||||
|
||||
// Common function to process publish or subscribe leafnode permission violation.
|
||||
// Sends the permission violation error to the remote, logs it and closes the connection.
|
||||
// If this is from a server soliciting, the reconnection will be delayed.
|
||||
@@ -3433,6 +3650,12 @@ func (s *Server) leafNodeFinishConnectProcess(c *client) {
|
||||
return
|
||||
}
|
||||
remote := c.leaf.remote
|
||||
if remote == nil || c.acc == nil {
|
||||
c.mu.Unlock()
|
||||
c.sendErr("Authorization Violation")
|
||||
c.closeConnection(ProtocolViolation)
|
||||
return
|
||||
}
|
||||
// Check if we will need to send the system connect event.
|
||||
remote.RLock()
|
||||
sendSysConnectEvent := remote.Hub
|
||||
@@ -3457,19 +3680,22 @@ func (s *Server) leafNodeFinishConnectProcess(c *client) {
|
||||
c.closeConnection(ProtocolViolation)
|
||||
return
|
||||
}
|
||||
s.addLeafNodeConnection(c, _EMPTY_, _EMPTY_, false)
|
||||
if !s.addLeafNodeConnection(c, _EMPTY_, _EMPTY_, false) {
|
||||
// Was not added, could be because the remote configuration has been removed.
|
||||
c.closeConnection(ClientClosed)
|
||||
return
|
||||
}
|
||||
s.initLeafNodeSmapAndSendSubs(c)
|
||||
if sendSysConnectEvent {
|
||||
s.sendLeafNodeConnect(acc)
|
||||
}
|
||||
s.accountConnectEvent(c)
|
||||
|
||||
// The above functions are not atomically under the client
|
||||
// lock doing those operations. It is possible - since we
|
||||
// have started the read/write loops - that the connection
|
||||
// is closed before or in between. This would leave the
|
||||
// closed LN connection possible registered with the account
|
||||
// and/or the server's leafs map. So check if connection
|
||||
// is closed, and if so, manually cleanup.
|
||||
// The above functions are not running under the client lock, so it is
|
||||
// possible that between the time we have started the read/write loops
|
||||
// and now, that the connection was closed. This would leave the closed
|
||||
// LN connection possibly registered with the account and/or the server's
|
||||
// leafs map. So check if connection is closed, and if so, manually cleanup.
|
||||
c.mu.Lock()
|
||||
closed := c.isClosed()
|
||||
if !closed {
|
||||
|
||||
+8
@@ -227,6 +227,14 @@ func (s *Server) rateLimitFormatWarnf(format string, v ...any) {
|
||||
s.Warnf("%s", statement)
|
||||
}
|
||||
|
||||
func (s *Server) RateLimitErrorf(format string, v ...any) {
|
||||
statement := fmt.Sprintf(format, v...)
|
||||
if _, loaded := s.rateLimitLogging.LoadOrStore(statement, time.Now()); loaded {
|
||||
return
|
||||
}
|
||||
s.Errorf("%s", statement)
|
||||
}
|
||||
|
||||
func (s *Server) RateLimitWarnf(format string, v ...any) {
|
||||
statement := fmt.Sprintf(format, v...)
|
||||
if _, loaded := s.rateLimitLogging.LoadOrStore(statement, time.Now()); loaded {
|
||||
|
||||
+137
-67
@@ -184,7 +184,7 @@ func (ms *memStore) recoverMsgSchedulingState() {
|
||||
if len(sm.hdr) == 0 {
|
||||
continue
|
||||
}
|
||||
if schedule, ok := getMessageSchedule(sm.hdr); ok && !schedule.IsZero() {
|
||||
if schedule, apiErr := nextMessageSchedule(sm.hdr, sm.ts); apiErr == nil && !schedule.IsZero() {
|
||||
ms.scheduling.init(seq, sm.subj, schedule.UnixNano())
|
||||
}
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func (ms *memStore) recoverMsgSchedulingState() {
|
||||
|
||||
// Stores a raw message with expected sequence number and timestamp.
|
||||
// Lock should be held.
|
||||
func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64) error {
|
||||
func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64, discardNewCheck bool) error {
|
||||
if ms.msgs == nil {
|
||||
return ErrStoreClosed
|
||||
}
|
||||
@@ -208,31 +208,31 @@ func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, tt
|
||||
}
|
||||
|
||||
// Check if we are discarding new messages when we reach the limit.
|
||||
if ms.cfg.Discard == DiscardNew {
|
||||
if asl && ms.cfg.DiscardNewPer {
|
||||
// If we are clustered, we do the enforcement above and should not disqualify
|
||||
// the message here since it could cause replicas to drift.
|
||||
if discardNewCheck && ms.cfg.Discard == DiscardNew {
|
||||
// Allow rollup messages through since they will purge old
|
||||
// messages for the subject after storing, restoring the limit.
|
||||
if asl && ms.cfg.DiscardNewPer && len(sliceHeader(JSMsgRollup, hdr)) == 0 {
|
||||
return ErrMaxMsgsPerSubject
|
||||
}
|
||||
// If we are discard new and limits policy and clustered, we do the enforcement
|
||||
// above and should not disqualify the message here since it could cause replicas to drift.
|
||||
if ms.cfg.Retention == LimitsPolicy || ms.cfg.Replicas == 1 {
|
||||
if ms.cfg.MaxMsgs > 0 && ms.state.Msgs >= uint64(ms.cfg.MaxMsgs) {
|
||||
// If we are tracking max messages per subject and are at the limit we will replace, so this is ok.
|
||||
if !asl {
|
||||
return ErrMaxMsgs
|
||||
}
|
||||
if ms.cfg.MaxMsgs > 0 && ms.state.Msgs >= uint64(ms.cfg.MaxMsgs) {
|
||||
// If we are tracking max messages per subject and are at the limit we will replace, so this is ok.
|
||||
if !asl {
|
||||
return ErrMaxMsgs
|
||||
}
|
||||
if ms.cfg.MaxBytes > 0 && ms.state.Bytes+memStoreMsgSize(subj, hdr, msg) >= uint64(ms.cfg.MaxBytes) {
|
||||
if !asl {
|
||||
return ErrMaxBytes
|
||||
}
|
||||
// If we are here we are at a subject maximum, need to determine if dropping last message gives us enough room.
|
||||
if ss.firstNeedsUpdate || ss.lastNeedsUpdate {
|
||||
ms.recalculateForSubj(subj, ss)
|
||||
}
|
||||
sm, ok := ms.msgs[ss.First]
|
||||
if !ok || memStoreMsgSize(sm.subj, sm.hdr, sm.msg) < memStoreMsgSize(subj, hdr, msg) {
|
||||
return ErrMaxBytes
|
||||
}
|
||||
}
|
||||
if ms.cfg.MaxBytes > 0 && ms.state.Bytes+memStoreMsgSize(subj, hdr, msg) > uint64(ms.cfg.MaxBytes) {
|
||||
if !asl {
|
||||
return ErrMaxBytes
|
||||
}
|
||||
// If we are here we are at a subject maximum, need to determine if dropping last message gives us enough room.
|
||||
if ss.firstNeedsUpdate || ss.lastNeedsUpdate {
|
||||
ms.recalculateForSubj(subj, ss)
|
||||
}
|
||||
sm, ok := ms.msgs[ss.First]
|
||||
if !ok || memStoreMsgSize(sm.subj, sm.hdr, sm.msg) < memStoreMsgSize(subj, hdr, msg) {
|
||||
return ErrMaxBytes
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -309,20 +309,28 @@ func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, tt
|
||||
|
||||
// Message scheduling.
|
||||
if ms.scheduling != nil {
|
||||
if schedule, ok := getMessageSchedule(hdr); ok && !schedule.IsZero() {
|
||||
if schedule, apiErr := nextMessageSchedule(hdr, ts); apiErr == nil && !schedule.IsZero() {
|
||||
ms.scheduling.add(seq, subj, schedule.UnixNano())
|
||||
} else {
|
||||
} else if getMessageScheduler(hdr) == _EMPTY_ {
|
||||
ms.scheduling.removeSubject(subj)
|
||||
}
|
||||
|
||||
// Check for a repeating schedule and update such that it triggers again.
|
||||
if scheduleNext := bytesToString(sliceHeader(JSScheduleNext, hdr)); scheduleNext != _EMPTY_ && scheduleNext != JSScheduleNextPurge {
|
||||
scheduler := getMessageScheduler(hdr)
|
||||
if next, err := time.Parse(time.RFC3339Nano, scheduleNext); err == nil && scheduler != _EMPTY_ {
|
||||
ms.scheduling.update(scheduler, next.UnixNano())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoreRawMsg stores a raw message with expected sequence number and timestamp.
|
||||
func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64) error {
|
||||
func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64, discardNewCheck bool) error {
|
||||
ms.mu.Lock()
|
||||
err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl)
|
||||
err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl, discardNewCheck)
|
||||
cb := ms.scb
|
||||
// Check if first message timestamp requires expiry
|
||||
// sooner than initial replica expiry timer set to MaxAge when initializing.
|
||||
@@ -344,7 +352,8 @@ func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts, tt
|
||||
func (ms *memStore) StoreMsg(subj string, hdr, msg []byte, ttl int64) (uint64, int64, error) {
|
||||
ms.mu.Lock()
|
||||
seq, ts := ms.state.LastSeq+1, time.Now().UnixNano()
|
||||
err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl)
|
||||
// This is called for a R1 with no expected sequence number, so perform DiscardNew checks on the store-level.
|
||||
err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl, true)
|
||||
cb := ms.scb
|
||||
ms.mu.Unlock()
|
||||
|
||||
@@ -414,8 +423,9 @@ func (ms *memStore) SkipMsgs(seq uint64, num uint64) error {
|
||||
}
|
||||
|
||||
// FlushAllPending flushes all data that was still pending to be written.
|
||||
func (ms *memStore) FlushAllPending() {
|
||||
func (ms *memStore) FlushAllPending() error {
|
||||
// Noop, in-memory store doesn't use async applying.
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterStorageUpdates registers a callback for updates to storage changes.
|
||||
@@ -521,13 +531,13 @@ loop:
|
||||
}
|
||||
|
||||
// FilteredState will return the SimpleState associated with the filtered subject and a proposed starting sequence.
|
||||
func (ms *memStore) FilteredState(sseq uint64, subj string) SimpleState {
|
||||
func (ms *memStore) FilteredState(sseq uint64, subj string) (SimpleState, error) {
|
||||
// This needs to be a write lock, as filteredStateLocked can
|
||||
// mutate the per-subject state.
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
return ms.filteredStateLocked(sseq, subj, false)
|
||||
return ms.filteredStateLocked(sseq, subj, false), nil
|
||||
}
|
||||
|
||||
func (ms *memStore) filteredStateLocked(sseq uint64, filter string, lastPerSubject bool) SimpleState {
|
||||
@@ -1391,10 +1401,16 @@ func (ms *memStore) runMsgScheduling() {
|
||||
}
|
||||
ms.scheduling.running = true
|
||||
|
||||
scheduledMsgs := ms.scheduling.getScheduledMessages(func(seq uint64, smv *StoreMsg) *StoreMsg {
|
||||
sm, _ := ms.loadMsgLocked(seq, smv, false)
|
||||
return sm
|
||||
})
|
||||
scheduledMsgs := ms.scheduling.getScheduledMessages(
|
||||
func(seq uint64, smv *StoreMsg) *StoreMsg {
|
||||
sm, _ := ms.loadMsgLocked(seq, smv, false)
|
||||
return sm
|
||||
},
|
||||
func(subj string, smv *StoreMsg) *StoreMsg {
|
||||
sm, _ := ms.loadLastLocked(subj, smv)
|
||||
return sm
|
||||
},
|
||||
)
|
||||
if len(scheduledMsgs) > 0 {
|
||||
ms.mu.Unlock()
|
||||
for _, msg := range scheduledMsgs {
|
||||
@@ -1429,7 +1445,7 @@ func (ms *memStore) PurgeEx(subject string, sequence, keep uint64) (purged uint6
|
||||
|
||||
}
|
||||
eq := compareFn(subject)
|
||||
if ss := ms.FilteredState(1, subject); ss.Msgs > 0 {
|
||||
if ss, _ := ms.FilteredState(1, subject); ss.Msgs > 0 {
|
||||
if keep > 0 {
|
||||
if keep >= ss.Msgs {
|
||||
return 0, nil
|
||||
@@ -1712,13 +1728,17 @@ func (ms *memStore) loadMsgLocked(seq uint64, smp *StoreMsg, needMSLock bool) (*
|
||||
// LoadLastMsg will return the last message we have that matches a given subject.
|
||||
// The subject can be a wildcard.
|
||||
func (ms *memStore) LoadLastMsg(subject string, smp *StoreMsg) (*StoreMsg, error) {
|
||||
var sm *StoreMsg
|
||||
var ok bool
|
||||
|
||||
// This needs to be a write lock, as filteredStateLocked can
|
||||
// mutate the per-subject state.
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
return ms.loadLastLocked(subject, smp)
|
||||
}
|
||||
|
||||
// Lock should be held.
|
||||
func (ms *memStore) loadLastLocked(subject string, smp *StoreMsg) (*StoreMsg, error) {
|
||||
var sm *StoreMsg
|
||||
var ok bool
|
||||
|
||||
if subject == _EMPTY_ || subject == fwcs {
|
||||
sm, ok = ms.msgs[ms.state.LastSeq]
|
||||
@@ -1907,31 +1927,41 @@ func (ms *memStore) loadNextMsgLocked(filter string, wc bool, start uint64, smp
|
||||
return nil, ms.state.LastSeq, ErrStoreEOF
|
||||
}
|
||||
|
||||
// Will load the next non-deleted msg starting at the start sequence and walking backwards.
|
||||
func (ms *memStore) LoadPrevMsg(start uint64, smp *StoreMsg) (sm *StoreMsg, err error) {
|
||||
// Will load the previous message matching the filter subject, starting at the start sequence and walking backwards.
|
||||
func (ms *memStore) LoadPrevMsg(filter string, wc bool, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error) {
|
||||
ms.mu.RLock()
|
||||
defer ms.mu.RUnlock()
|
||||
|
||||
if ms.msgs == nil {
|
||||
return nil, ErrStoreClosed
|
||||
return nil, 0, ErrStoreClosed
|
||||
}
|
||||
if ms.state.Msgs == 0 || start < ms.state.FirstSeq {
|
||||
return nil, ErrStoreEOF
|
||||
return nil, ms.state.FirstSeq, ErrStoreEOF
|
||||
}
|
||||
if start > ms.state.LastSeq {
|
||||
start = ms.state.LastSeq
|
||||
}
|
||||
|
||||
if filter == _EMPTY_ {
|
||||
filter = fwcs
|
||||
wc = true
|
||||
}
|
||||
isAll := filter == fwcs
|
||||
eq := subjectsEqual
|
||||
if wc {
|
||||
eq = matchLiteral
|
||||
}
|
||||
|
||||
for seq := start; seq >= ms.state.FirstSeq; seq-- {
|
||||
if sm, ok := ms.msgs[seq]; ok {
|
||||
if sm, ok := ms.msgs[seq]; ok && (isAll || eq(sm.subj, filter)) {
|
||||
if smp == nil {
|
||||
smp = new(StoreMsg)
|
||||
}
|
||||
sm.copy(smp)
|
||||
return smp, nil
|
||||
return smp, seq, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrStoreEOF
|
||||
return nil, ms.state.FirstSeq, ErrStoreEOF
|
||||
}
|
||||
|
||||
// LoadPrevMsgMulti will find the previous message matching any entry in the sublist.
|
||||
@@ -1965,7 +1995,7 @@ func (ms *memStore) LoadPrevMsgMulti(sl *gsl.SimpleSublist, start uint64, smp *S
|
||||
return smp, nseq, nil
|
||||
}
|
||||
}
|
||||
return nil, ms.state.LastSeq, ErrStoreEOF
|
||||
return nil, ms.state.FirstSeq, ErrStoreEOF
|
||||
}
|
||||
|
||||
// RemoveMsg will remove the message from this store.
|
||||
@@ -2329,10 +2359,7 @@ func (ms *memStore) EncodedStreamState(failed uint64) ([]byte, error) {
|
||||
b := buf[0:n]
|
||||
|
||||
if numDeleted > 0 {
|
||||
buf, err := ms.dmap.Encode(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := ms.dmap.Encode(nil)
|
||||
b = append(b, buf...)
|
||||
}
|
||||
|
||||
@@ -2340,7 +2367,11 @@ func (ms *memStore) EncodedStreamState(failed uint64) ([]byte, error) {
|
||||
}
|
||||
|
||||
// SyncDeleted will make sure this stream has same deleted state as dbs.
|
||||
func (ms *memStore) SyncDeleted(dbs DeleteBlocks) {
|
||||
func (ms *memStore) SyncDeleted(dbs DeleteBlocks) error {
|
||||
if len(dbs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
@@ -2349,7 +2380,7 @@ func (ms *memStore) SyncDeleted(dbs DeleteBlocks) {
|
||||
if len(dbs) == 1 {
|
||||
min, max, num := ms.dmap.State()
|
||||
if pmin, pmax, pnum := dbs[0].State(); pmin == min && pmax == max && pnum == num {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
}
|
||||
lseq := ms.state.LastSeq
|
||||
@@ -2363,6 +2394,7 @@ func (ms *memStore) SyncDeleted(dbs DeleteBlocks) {
|
||||
return true
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *consumerMemStore) Update(state *ConsumerState) error {
|
||||
@@ -2410,10 +2442,51 @@ func (o *consumerMemStore) Update(state *ConsumerState) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *consumerMemStore) ForceUpdate(state *ConsumerState) error {
|
||||
// Sanity checks.
|
||||
if state.AckFloor.Consumer > state.Delivered.Consumer {
|
||||
return fmt.Errorf("bad ack floor for consumer")
|
||||
}
|
||||
if state.AckFloor.Stream > state.Delivered.Stream {
|
||||
return fmt.Errorf("bad ack floor for stream")
|
||||
}
|
||||
|
||||
// Copy to our state.
|
||||
var pending map[uint64]*Pending
|
||||
var redelivered map[uint64]uint64
|
||||
if len(state.Pending) > 0 {
|
||||
pending = make(map[uint64]*Pending, len(state.Pending))
|
||||
for seq, p := range state.Pending {
|
||||
pending[seq] = &Pending{p.Sequence, p.Timestamp}
|
||||
if seq <= state.AckFloor.Stream || seq > state.Delivered.Stream {
|
||||
return fmt.Errorf("bad pending entry, sequence [%d] out of range", seq)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(state.Redelivered) > 0 {
|
||||
redelivered = make(map[uint64]uint64, len(state.Redelivered))
|
||||
for seq, dc := range state.Redelivered {
|
||||
redelivered[seq] = dc
|
||||
}
|
||||
}
|
||||
|
||||
// Replace our state.
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.state.Delivered = state.Delivered
|
||||
o.state.AckFloor = state.AckFloor
|
||||
o.state.Pending = pending
|
||||
o.state.Redelivered = redelivered
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetStarting sets our starting stream sequence.
|
||||
func (o *consumerMemStore) SetStarting(sseq uint64) error {
|
||||
o.mu.Lock()
|
||||
o.state.Delivered.Stream = sseq
|
||||
o.state.AckFloor.Stream = sseq
|
||||
o.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
@@ -2432,6 +2505,14 @@ func (o *consumerMemStore) UpdateStarting(sseq uint64) {
|
||||
}
|
||||
}
|
||||
|
||||
// Reset all values in the store, and reset the starting sequence.
|
||||
func (o *consumerMemStore) Reset(sseq uint64) error {
|
||||
o.mu.Lock()
|
||||
o.state = ConsumerState{}
|
||||
o.mu.Unlock()
|
||||
return o.SetStarting(sseq)
|
||||
}
|
||||
|
||||
// HasState returns if this store has a recorded state.
|
||||
func (o *consumerMemStore) HasState() bool {
|
||||
o.mu.Lock()
|
||||
@@ -2524,8 +2605,8 @@ func (o *consumerMemStore) UpdateAcks(dseq, sseq uint64) error {
|
||||
return ErrStoreMsgNotFound
|
||||
}
|
||||
|
||||
// Check for AckAll here.
|
||||
if o.cfg.AckPolicy == AckAll {
|
||||
// Check for AckAll here (or AckFlowControl which functions like AckAll).
|
||||
if o.cfg.AckPolicy == AckAll || o.cfg.AckPolicy == AckFlowControl {
|
||||
sgap := sseq - o.state.AckFloor.Stream
|
||||
o.state.AckFloor.Consumer = dseq
|
||||
o.state.AckFloor.Stream = sseq
|
||||
@@ -2675,14 +2756,3 @@ func (o *consumerMemStore) copyRedelivered() map[uint64]uint64 {
|
||||
|
||||
// Type returns the type of the underlying store.
|
||||
func (o *consumerMemStore) Type() StorageType { return MemoryStorage }
|
||||
|
||||
// Templates
|
||||
type templateMemStore struct{}
|
||||
|
||||
func newTemplateMemStore() *templateMemStore {
|
||||
return &templateMemStore{}
|
||||
}
|
||||
|
||||
// No-ops for memstore.
|
||||
func (ts *templateMemStore) Store(t *streamTemplate) error { return nil }
|
||||
func (ts *templateMemStore) Delete(t *streamTemplate) error { return nil }
|
||||
|
||||
+101
-61
@@ -189,6 +189,17 @@ func newSubsList(client *client) []string {
|
||||
return subs
|
||||
}
|
||||
|
||||
func redactBearerJWT(userJWT string) string {
|
||||
if userJWT == _EMPTY_ {
|
||||
return _EMPTY_
|
||||
}
|
||||
uc, err := jwt.DecodeUserClaims(userJWT)
|
||||
if err == nil && uc != nil && uc.BearerToken {
|
||||
return _EMPTY_
|
||||
}
|
||||
return userJWT
|
||||
}
|
||||
|
||||
// Connz returns a Connz struct containing information about connections.
|
||||
func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) {
|
||||
var (
|
||||
@@ -441,6 +452,7 @@ func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) {
|
||||
ci.NameTag = client.acc.getNameTag()
|
||||
}
|
||||
client.mu.Unlock()
|
||||
ci.JWT = redactBearerJWT(ci.JWT)
|
||||
pconns[i] = ci
|
||||
i++
|
||||
}
|
||||
@@ -487,6 +499,7 @@ func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) {
|
||||
cc.NameTag = acc.getNameTag()
|
||||
}
|
||||
}
|
||||
cc.JWT = redactBearerJWT(cc.JWT)
|
||||
}
|
||||
pconns[i] = &cc.ConnInfo
|
||||
i++
|
||||
@@ -1271,6 +1284,7 @@ type Varz struct {
|
||||
ConfigDigest string `json:"config_digest"` // ConfigDigest is a calculated hash of the current configuration
|
||||
Tags jwt.TagList `json:"tags,omitempty"` // Tags are the tags assigned to the server in configuration
|
||||
Metadata map[string]string `json:"metadata,omitempty"` // Metadata is the metadata assigned to the server in configuration
|
||||
FeatureFlags map[string]bool `json:"feature_flags,omitempty"` // FeatureFlags is the feature flags enabled/disabled in configuration
|
||||
TrustedOperatorsJwt []string `json:"trusted_operators_jwt,omitempty"` // TrustedOperatorsJwt is the JWTs for all trusted operators
|
||||
TrustedOperatorsClaim []*jwt.OperatorClaims `json:"trusted_operators_claim,omitempty"` // TrustedOperatorsClaim is the decoded claims for each trusted operator
|
||||
SystemAccount string `json:"system_account,omitempty"` // SystemAccount is the name of the System account
|
||||
@@ -1570,8 +1584,13 @@ func (s *Server) updateJszVarz(js *jetStream, v *JetStreamVarz, doConfig bool) {
|
||||
v.Meta.Replicas = ci.Replicas
|
||||
}
|
||||
if ipq := s.jsAPIRoutedReqs; ipq != nil {
|
||||
v.Meta.Pending = ipq.len()
|
||||
v.Meta.PendingRequests = ipq.len()
|
||||
}
|
||||
if ipq := s.jsAPIRoutedInfoReqs; ipq != nil {
|
||||
v.Meta.PendingInfos = ipq.len()
|
||||
}
|
||||
v.Meta.Pending = v.Meta.PendingRequests + v.Meta.PendingInfos
|
||||
v.Meta.Snapshot = s.metaClusterSnapshotStats(js, mg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1788,6 +1807,7 @@ func (s *Server) updateVarzConfigReloadableFields(v *Varz) {
|
||||
v.ConfigDigest = opts.configDigest
|
||||
v.Tags = opts.Tags
|
||||
v.Metadata = opts.Metadata
|
||||
v.FeatureFlags = opts.getMergedFeatureFlags()
|
||||
// Update route URLs if applicable
|
||||
if s.varzUpdateRouteURLs {
|
||||
v.Cluster.URLs = urlsToStrings(opts.Routes)
|
||||
@@ -3008,15 +3028,43 @@ type MetaSnapshotStats struct {
|
||||
LastDuration time.Duration `json:"last_duration,omitempty"` // LastDuration is how long the last meta snapshot took
|
||||
}
|
||||
|
||||
// metaClusterSnapshotStats returns snapshot statistics for the meta group.
|
||||
func (s *Server) metaClusterSnapshotStats(js *jetStream, mg RaftNode) *MetaSnapshotStats {
|
||||
entries, bytes := mg.Size()
|
||||
snap := &MetaSnapshotStats{
|
||||
PendingEntries: entries,
|
||||
PendingSize: bytes,
|
||||
}
|
||||
|
||||
js.mu.RLock()
|
||||
cluster := js.cluster
|
||||
js.mu.RUnlock()
|
||||
|
||||
if cluster != nil {
|
||||
timeNanos := atomic.LoadInt64(&cluster.lastMetaSnapTime)
|
||||
durationNanos := atomic.LoadInt64(&cluster.lastMetaSnapDuration)
|
||||
if timeNanos > 0 {
|
||||
snap.LastTime = time.Unix(0, timeNanos).UTC()
|
||||
}
|
||||
if durationNanos > 0 {
|
||||
snap.LastDuration = time.Duration(durationNanos)
|
||||
}
|
||||
}
|
||||
|
||||
return snap
|
||||
}
|
||||
|
||||
// MetaClusterInfo shows information about the meta group.
|
||||
type MetaClusterInfo struct {
|
||||
Name string `json:"name,omitempty"` // Name is the name of the cluster
|
||||
Leader string `json:"leader,omitempty"` // Leader is the server name of the cluster leader
|
||||
Peer string `json:"peer,omitempty"` // Peer is unique ID of the leader
|
||||
Replicas []*PeerInfo `json:"replicas,omitempty"` // Replicas is a list of known peers
|
||||
Size int `json:"cluster_size"` // Size is the known size of the cluster
|
||||
Pending int `json:"pending"` // Pending is how many RAFT messages are not yet processed
|
||||
Snapshot *MetaSnapshotStats `json:"snapshot"` // Snapshot contains meta snapshot statistics
|
||||
Name string `json:"name,omitempty"` // Name is the name of the cluster
|
||||
Leader string `json:"leader,omitempty"` // Leader is the server name of the cluster leader
|
||||
Peer string `json:"peer,omitempty"` // Peer is unique ID of the leader
|
||||
Replicas []*PeerInfo `json:"replicas,omitempty"` // Replicas is a list of known peers
|
||||
Size int `json:"cluster_size"` // Size is the known size of the cluster
|
||||
Pending int `json:"pending"` // Pending is how many RAFT messages are not yet processed
|
||||
PendingRequests int `json:"pending_requests"` // PendingRequests is how many CRUD operations are queued for processing
|
||||
PendingInfos int `json:"pending_infos"` // PendingInfos is how many info operations are queued for processing
|
||||
Snapshot *MetaSnapshotStats `json:"snapshot"` // Snapshot contains meta snapshot statistics
|
||||
}
|
||||
|
||||
// JSInfo has detailed information on JetStream.
|
||||
@@ -3233,32 +3281,18 @@ func (s *Server) Jsz(opts *JSzOptions) (*JSInfo, error) {
|
||||
|
||||
if mg := js.getMetaGroup(); mg != nil {
|
||||
if ci := s.raftNodeToClusterInfo(mg); ci != nil {
|
||||
entries, bytes := mg.Size()
|
||||
jsi.Meta = &MetaClusterInfo{Name: ci.Name, Leader: ci.Leader, Peer: getHash(ci.Leader), Size: mg.ClusterSize()}
|
||||
if isLeader {
|
||||
jsi.Meta.Replicas = ci.Replicas
|
||||
}
|
||||
if ipq := s.jsAPIRoutedReqs; ipq != nil {
|
||||
jsi.Meta.Pending = ipq.len()
|
||||
jsi.Meta.PendingRequests = ipq.len()
|
||||
}
|
||||
// Add meta snapshot stats
|
||||
jsi.Meta.Snapshot = &MetaSnapshotStats{
|
||||
PendingEntries: entries,
|
||||
PendingSize: bytes,
|
||||
}
|
||||
js.mu.RLock()
|
||||
cluster := js.cluster
|
||||
js.mu.RUnlock()
|
||||
if cluster != nil {
|
||||
timeNanos := atomic.LoadInt64(&cluster.lastMetaSnapTime)
|
||||
durationNanos := atomic.LoadInt64(&cluster.lastMetaSnapDuration)
|
||||
if timeNanos > 0 {
|
||||
jsi.Meta.Snapshot.LastTime = time.Unix(0, timeNanos).UTC()
|
||||
}
|
||||
if durationNanos > 0 {
|
||||
jsi.Meta.Snapshot.LastDuration = time.Duration(durationNanos)
|
||||
}
|
||||
if ipq := s.jsAPIRoutedInfoReqs; ipq != nil {
|
||||
jsi.Meta.PendingInfos = ipq.len()
|
||||
}
|
||||
jsi.Meta.Pending = jsi.Meta.PendingRequests + jsi.Meta.PendingInfos
|
||||
jsi.Meta.Snapshot = s.metaClusterSnapshotStats(js, mg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3695,6 +3729,20 @@ func (s *Server) healthz(opts *HealthzOptions) *HealthStatus {
|
||||
})
|
||||
continue
|
||||
}
|
||||
if streamWerr := s.getWriteErr(); streamWerr != nil {
|
||||
if !details {
|
||||
health.Status = na
|
||||
health.Error = fmt.Sprintf("JetStream stream '%s > %s' write error: %v", acc, stream, streamWerr)
|
||||
return health
|
||||
}
|
||||
health.Errors = append(health.Errors, HealthzError{
|
||||
Type: HealthzErrorStream,
|
||||
Account: acc.Name,
|
||||
Stream: stream,
|
||||
Error: fmt.Sprintf("JetStream stream '%s > %s' write error: %v", acc, stream, streamWerr),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if streamFound {
|
||||
// if consumer option is passed, verify that the consumer exists on stream
|
||||
if opts.Consumer != _EMPTY_ {
|
||||
@@ -3771,49 +3819,37 @@ func (s *Server) healthz(opts *HealthzOptions) *HealthStatus {
|
||||
meta = cc.meta
|
||||
js.mu.RUnlock()
|
||||
|
||||
// If no meta leader.
|
||||
if meta == nil || meta.GroupLeader() == _EMPTY_ {
|
||||
if !details {
|
||||
health.Status = na
|
||||
health.Error = "JetStream has not established contact with a meta leader"
|
||||
} else {
|
||||
health.Errors = []HealthzError{
|
||||
{
|
||||
Type: HealthzErrorJetStream,
|
||||
Error: "JetStream has not established contact with a meta leader",
|
||||
},
|
||||
}
|
||||
}
|
||||
return health
|
||||
// Check meta layer health.
|
||||
var metaNoLeader, metaClosed, metaUnhealthy bool
|
||||
var metaWerr error
|
||||
if meta != nil {
|
||||
metaNoLeader = meta.GroupLeader() == _EMPTY_
|
||||
metaClosed = meta.State() == Closed
|
||||
metaUnhealthy = !meta.Healthy()
|
||||
metaWerr = meta.GetWriteErr()
|
||||
}
|
||||
|
||||
// If we are not current with the meta leader.
|
||||
if !meta.Healthy() {
|
||||
if !details {
|
||||
health.Status = na
|
||||
health.Error = "JetStream is not current with the meta leader"
|
||||
metaRecovering := js.isMetaRecovering()
|
||||
if meta == nil || metaNoLeader || metaClosed || metaUnhealthy || metaWerr != nil || metaRecovering {
|
||||
var desc string
|
||||
if metaWerr != nil {
|
||||
desc = fmt.Sprintf("JetStream meta layer write error: %v", metaWerr)
|
||||
} else if metaClosed {
|
||||
desc = "JetStream meta layer is not running"
|
||||
} else if meta != nil && metaRecovering {
|
||||
desc = "JetStream is still recovering meta layer"
|
||||
} else if meta == nil || metaNoLeader {
|
||||
desc = "JetStream has not established contact with a meta leader"
|
||||
} else {
|
||||
health.Errors = []HealthzError{
|
||||
{
|
||||
Type: HealthzErrorJetStream,
|
||||
Error: "JetStream is not current with the meta leader",
|
||||
},
|
||||
}
|
||||
desc = "JetStream is not current with the meta leader"
|
||||
}
|
||||
return health
|
||||
}
|
||||
|
||||
// Are we still recovering meta layer?
|
||||
if js.isMetaRecovering() {
|
||||
if !details {
|
||||
health.Status = na
|
||||
health.Error = "JetStream is still recovering meta layer"
|
||||
|
||||
health.Error = desc
|
||||
} else {
|
||||
health.Errors = []HealthzError{
|
||||
{
|
||||
Type: HealthzErrorJetStream,
|
||||
Error: "JetStream is still recovering meta layer",
|
||||
Error: desc,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -4090,6 +4126,8 @@ type RaftzGroup struct {
|
||||
QuorumNeeded int `json:"quorum_needed"`
|
||||
Observer bool `json:"observer,omitempty"`
|
||||
Paused bool `json:"paused,omitempty"`
|
||||
Overrun bool `json:"overrun,omitempty"`
|
||||
OverrunCount uint64 `json:"overrun_count,omitempty"`
|
||||
Committed uint64 `json:"committed"`
|
||||
Applied uint64 `json:"applied"`
|
||||
CatchingUp bool `json:"catching_up,omitempty"`
|
||||
@@ -4198,6 +4236,8 @@ func (s *Server) Raftz(opts *RaftzOptions) *RaftzStatus {
|
||||
QuorumNeeded: n.qn,
|
||||
Observer: n.observer,
|
||||
Paused: n.paused,
|
||||
Overrun: n.quorumPaused || n.isLeaderOverrun(),
|
||||
OverrunCount: n.overrunCount,
|
||||
Committed: n.commit,
|
||||
Applied: n.applied,
|
||||
CatchingUp: n.catchup != nil,
|
||||
|
||||
+78
-85
@@ -15,7 +15,6 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
@@ -33,6 +32,8 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nats-server/v2/server/gsl"
|
||||
"github.com/nats-io/nats-server/v2/server/stree"
|
||||
"github.com/nats-io/nuid"
|
||||
)
|
||||
|
||||
@@ -258,14 +259,13 @@ type mqttSessionManager struct {
|
||||
|
||||
type mqttAccountSessionManager struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*mqttSession // key is MQTT client ID
|
||||
sessByHash map[string]*mqttSession // key is MQTT client ID hash
|
||||
sessLocked map[string]struct{} // key is MQTT client ID and indicate that a session can not be taken by a new client at this time
|
||||
flappers map[string]time.Time // When connection connects with client ID already in use
|
||||
flapTimer *time.Timer // Timer to perform some cleanup of the flappers map
|
||||
sl *Sublist // sublist allowing to find retained messages for given subscription
|
||||
retmsgs map[string]*mqttRetainedMsgRef // retained messages
|
||||
rmsCache *sync.Map // map[subject]mqttRetainedMsg
|
||||
sessions map[string]*mqttSession // key is MQTT client ID
|
||||
sessByHash map[string]*mqttSession // key is MQTT client ID hash
|
||||
sessLocked map[string]struct{} // key is MQTT client ID and indicate that a session can not be taken by a new client at this time
|
||||
flappers map[string]time.Time // When connection connects with client ID already in use
|
||||
flapTimer *time.Timer // Timer to perform some cleanup of the flappers map
|
||||
retmsgs *stree.SubjectTree[mqttRetainedMsgRef] // retained message metadata
|
||||
rmsCache *sync.Map // map[subject]mqttRetainedMsg
|
||||
jsa mqttJSA
|
||||
domainTk string // Domain (with trailing "."), or possibly empty. This is added to session subject.
|
||||
}
|
||||
@@ -366,7 +366,6 @@ type mqttRetainedMsg struct {
|
||||
|
||||
type mqttRetainedMsgRef struct {
|
||||
sseq uint64
|
||||
sub *subscription
|
||||
}
|
||||
|
||||
// mqttSub contains fields associated with a MQTT subscription, and is added to
|
||||
@@ -2022,10 +2021,14 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if strings.IndexByte(rm.Subject, 0x7f) >= 0 {
|
||||
c.Warnf("Skipping retained message for subject %q: unsupported character 0x7f", rm.Subject)
|
||||
return
|
||||
}
|
||||
// The as.jsa.id is immutable, so no need to have a rlock here.
|
||||
local := rm.Origin == as.jsa.id
|
||||
// Get the stream sequence for this message.
|
||||
seq, _, _ := ackReplyInfo(reply)
|
||||
seq, _, _, _, _ := ackReplyInfo(reply)
|
||||
if len(m) == 0 {
|
||||
// An empty payload means that we need to remove the retained message.
|
||||
rmSeq := as.removeRetainedMsg(rm.Subject, 0)
|
||||
@@ -2042,7 +2045,7 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie
|
||||
// Add this retained message. The `rm.Msg` references some buffer that we
|
||||
// don't own. But addRetainedMsg() will take care of making a copy of
|
||||
// `rm.Msg` it `rm` ends-up being stored in the cache.
|
||||
as.addRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm)
|
||||
as.addRetainedMsg(rm.Subject, seq, rm)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2310,17 +2313,16 @@ func (as *mqttAccountSessionManager) sendJSAPIrequests(s *Server, c *client, acc
|
||||
// If a message for this topic already existed, the existing record is updated
|
||||
// with the provided information.
|
||||
// Lock not held on entry.
|
||||
func (as *mqttAccountSessionManager) addRetainedMsg(key string, rf *mqttRetainedMsgRef, rm *mqttRetainedMsg) {
|
||||
func (as *mqttAccountSessionManager) addRetainedMsg(key string, sseq uint64, rm *mqttRetainedMsg) {
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
if as.retmsgs == nil {
|
||||
as.retmsgs = make(map[string]*mqttRetainedMsgRef)
|
||||
as.sl = NewSublistWithCache()
|
||||
as.retmsgs = stree.NewSubjectTree[mqttRetainedMsgRef]()
|
||||
} else {
|
||||
// Check if we already had one retained message. If so, update the existing one.
|
||||
if erf, exists := as.retmsgs[key]; exists {
|
||||
if erf, exists := as.retmsgs.Find(stringToBytes(key)); exists {
|
||||
// Update the stream sequence with the new value.
|
||||
erf.sseq = rf.sseq
|
||||
erf.sseq = sseq
|
||||
// Update the in-memory retained message cache but only for messages
|
||||
// that are already in the cache, i.e. have been (recently) used.
|
||||
// If that is the case, we ask setCachedRetainedMsg() to make a copy
|
||||
@@ -2329,9 +2331,7 @@ func (as *mqttAccountSessionManager) addRetainedMsg(key string, rf *mqttRetained
|
||||
return
|
||||
}
|
||||
}
|
||||
rf.sub = &subscription{subject: []byte(key)}
|
||||
as.retmsgs[key] = rf
|
||||
as.sl.Insert(rf.sub)
|
||||
as.retmsgs.Insert([]byte(key), mqttRetainedMsgRef{sseq: sseq})
|
||||
}
|
||||
|
||||
// Remove the retained message stored with the `subject` key from the map/cache.
|
||||
@@ -2348,15 +2348,13 @@ func (as *mqttAccountSessionManager) addRetainedMsg(key string, rf *mqttRetained
|
||||
func (as *mqttAccountSessionManager) removeRetainedMsg(subject string, seq uint64) uint64 {
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
rm, ok := as.retmsgs[subject]
|
||||
rm, ok := as.retmsgs.Find(stringToBytes(subject))
|
||||
if !ok || (seq > 0 && rm.sseq != seq) {
|
||||
return 0
|
||||
}
|
||||
seq = rm.sseq
|
||||
rm, _ = as.retmsgs.Delete(stringToBytes(subject))
|
||||
as.rmsCache.Delete(subject)
|
||||
delete(as.retmsgs, subject)
|
||||
as.sl.Remove(rm.sub)
|
||||
return seq
|
||||
return rm.sseq
|
||||
}
|
||||
|
||||
// First check if this session's client ID is already in the "locked" map,
|
||||
@@ -2684,27 +2682,22 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client,
|
||||
// Account session manager lock held on entry.
|
||||
// Session lock held on entry.
|
||||
func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string]*mqttRetainedMsg, sess *mqttSession, c *client, sub *subscription, trace bool) {
|
||||
if len(as.retmsgs) == 0 || len(rms) == 0 {
|
||||
return
|
||||
}
|
||||
result := as.sl.ReverseMatch(string(sub.subject))
|
||||
if len(result.psubs) == 0 {
|
||||
if as.retmsgs.Size() == 0 || len(rms) == 0 {
|
||||
return
|
||||
}
|
||||
toTrace := []mqttPublish{}
|
||||
for _, psub := range result.psubs {
|
||||
|
||||
rm := rms[string(psub.subject)]
|
||||
as.retmsgs.Match(sub.subject, func(subj []byte, _ *mqttRetainedMsgRef) {
|
||||
rm := rms[string(subj)]
|
||||
if rm == nil {
|
||||
// This should not happen since we pre-load messages into rms before
|
||||
// calling serialize.
|
||||
continue
|
||||
return
|
||||
}
|
||||
var pi uint16
|
||||
qos := min(mqttGetQoS(rm.Flags), sub.mqtt.qos)
|
||||
if c.mqtt.rejectQoS2Pub && qos == 2 {
|
||||
c.Warnf("Rejecting retained message with QoS2 for subscription %q, as configured", sub.subject)
|
||||
continue
|
||||
return
|
||||
}
|
||||
if qos > 0 {
|
||||
pi = sess.trackPublishRetained()
|
||||
@@ -2731,7 +2724,7 @@ func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string]
|
||||
sz: len(rm.Msg),
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
for _, pp := range toTrace {
|
||||
c.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp)))
|
||||
}
|
||||
@@ -2743,27 +2736,21 @@ func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string]
|
||||
// Account session manager NOT lock held on entry.
|
||||
func (as *mqttAccountSessionManager) addRetainedSubjectsForSubject(list map[string]uint64, topSubject string) {
|
||||
as.mu.RLock()
|
||||
if len(as.retmsgs) == 0 {
|
||||
as.mu.RUnlock()
|
||||
defer as.mu.RUnlock()
|
||||
|
||||
if as.retmsgs.Size() == 0 {
|
||||
return
|
||||
}
|
||||
result := as.sl.ReverseMatch(topSubject)
|
||||
as.mu.RUnlock()
|
||||
|
||||
for _, sub := range result.psubs {
|
||||
if _, ok := list[string(sub.subject)]; ok {
|
||||
continue
|
||||
as.retmsgs.Match(stringToBytes(topSubject), func(subj []byte, ret *mqttRetainedMsgRef) {
|
||||
subject := string(subj)
|
||||
if _, ok := list[subject]; ok {
|
||||
return
|
||||
}
|
||||
var seq uint64
|
||||
as.mu.RLock()
|
||||
if rm, ok := as.retmsgs[string(sub.subject)]; ok {
|
||||
seq = rm.sseq
|
||||
if seq := ret.sseq; seq > 0 {
|
||||
list[subject] = seq
|
||||
}
|
||||
as.mu.RUnlock()
|
||||
if seq > 0 {
|
||||
list[string(sub.subject)] = seq
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type warner interface {
|
||||
@@ -3457,7 +3444,7 @@ func (sess *mqttSession) trackPublish(jsDur, jsAckSubject string) (uint16, bool)
|
||||
}
|
||||
|
||||
// Get the stream sequence and duplicate flag from the ack reply subject.
|
||||
sseq, _, dcount := ackReplyInfo(jsAckSubject)
|
||||
sseq, _, dcount, _, _ := ackReplyInfo(jsAckSubject)
|
||||
if dcount > 1 {
|
||||
dup = true
|
||||
}
|
||||
@@ -3561,7 +3548,7 @@ func (sess *mqttSession) trackAsPubRel(pi uint16, jsAckSubject string) {
|
||||
return
|
||||
}
|
||||
|
||||
sseq, _, _ := ackReplyInfo(jsAckSubject)
|
||||
sseq, _, _, _, _ := ackReplyInfo(jsAckSubject)
|
||||
|
||||
if sess.cpending == nil {
|
||||
sess.cpending = make(map[string]map[uint64]uint16)
|
||||
@@ -4568,13 +4555,8 @@ func (s *Server) mqttCheckPubRetainedPerms() {
|
||||
}
|
||||
sm.mu.RUnlock()
|
||||
|
||||
type retainedMsg struct {
|
||||
subj string
|
||||
rmsg *mqttRetainedMsgRef
|
||||
}
|
||||
|
||||
// For each session we will obtain a list of retained messages.
|
||||
var _rms [128]retainedMsg
|
||||
var _rms [128]uint64
|
||||
rms := _rms[:0]
|
||||
for _, asm := range asms {
|
||||
// Get all of the retained messages. Then we will sort them so
|
||||
@@ -4582,19 +4564,20 @@ func (s *Server) mqttCheckPubRetainedPerms() {
|
||||
// store to not have to load out-of-order blocks so often.
|
||||
asm.mu.RLock()
|
||||
rms = rms[:0] // reuse slice
|
||||
for subj, rf := range asm.retmsgs {
|
||||
rms = append(rms, retainedMsg{
|
||||
subj: subj,
|
||||
rmsg: rf,
|
||||
})
|
||||
}
|
||||
// Copy the sequence out of the tree. The tree entry itself can be
|
||||
// updated concurrently by addRetainedMsg() after we release the lock,
|
||||
// so keeping a pointer here would race with the later sort.
|
||||
asm.retmsgs.IterOrdered(func(_ []byte, rm *mqttRetainedMsgRef) bool {
|
||||
rms = append(rms, rm.sseq)
|
||||
return true
|
||||
})
|
||||
jsaID := asm.jsa.id
|
||||
asm.mu.RUnlock()
|
||||
slices.SortFunc(rms, func(i, j retainedMsg) int { return cmp.Compare(i.rmsg.sseq, j.rmsg.sseq) })
|
||||
slices.Sort(rms)
|
||||
|
||||
perms := map[string]*perm{}
|
||||
perms := map[string]*mqttPerm{}
|
||||
for _, rf := range rms {
|
||||
jsm, err := asm.jsa.loadMsg(mqttRetainedMsgsStreamName, rf.rmsg.sseq)
|
||||
jsm, err := asm.jsa.loadMsg(mqttRetainedMsgsStreamName, rf)
|
||||
if err != nil || jsm == nil {
|
||||
continue
|
||||
}
|
||||
@@ -4617,7 +4600,7 @@ func (s *Server) mqttCheckPubRetainedPerms() {
|
||||
}
|
||||
// If there is permission and no longer allowed to publish in
|
||||
// the subject, remove the publish retained message from the map.
|
||||
if p != nil && !pubAllowed(p, rf.subj) {
|
||||
if p != nil && !pubAllowed(p, rm.Subject) {
|
||||
u = nil
|
||||
}
|
||||
}
|
||||
@@ -4636,7 +4619,12 @@ func (s *Server) mqttCheckPubRetainedPerms() {
|
||||
}
|
||||
|
||||
// Helper to generate only pub permissions from a Permissions object
|
||||
func generatePubPerms(perms *Permissions) *perm {
|
||||
type mqttPerm struct {
|
||||
allow *gsl.SimpleSublist
|
||||
deny *gsl.SimpleSublist
|
||||
}
|
||||
|
||||
func generatePubPerms(perms *Permissions) *mqttPerm {
|
||||
// If given permissions is `nil`, then it means that permissions block
|
||||
// has been removed (so the user is now allowed to publish on everything)
|
||||
// or was never there in the first place. Returning `nil` will let the
|
||||
@@ -4644,39 +4632,38 @@ func generatePubPerms(perms *Permissions) *perm {
|
||||
if perms == nil {
|
||||
return nil
|
||||
}
|
||||
var p *perm
|
||||
var p *mqttPerm
|
||||
if perms.Publish.Allow != nil {
|
||||
p = &perm{}
|
||||
p.allow = NewSublistWithCache()
|
||||
p = &mqttPerm{}
|
||||
p.allow = gsl.NewSimpleSublist()
|
||||
for _, pubSubject := range perms.Publish.Allow {
|
||||
sub := &subscription{subject: []byte(pubSubject)}
|
||||
p.allow.Insert(sub)
|
||||
_ = p.allow.Insert(pubSubject, struct{}{})
|
||||
}
|
||||
}
|
||||
if len(perms.Publish.Deny) > 0 {
|
||||
if p == nil {
|
||||
p = &perm{}
|
||||
p = &mqttPerm{}
|
||||
}
|
||||
p.deny = NewSublistWithCache()
|
||||
p.deny = gsl.NewSimpleSublist()
|
||||
for _, pubSubject := range perms.Publish.Deny {
|
||||
sub := &subscription{subject: []byte(pubSubject)}
|
||||
p.deny.Insert(sub)
|
||||
_ = p.deny.Insert(pubSubject, struct{}{})
|
||||
}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Helper that checks if given `perms` allow to publish on the given `subject`
|
||||
func pubAllowed(perms *perm, subject string) bool {
|
||||
func pubAllowed(perms *mqttPerm, subject string) bool {
|
||||
if perms == nil {
|
||||
return true
|
||||
}
|
||||
allowed := true
|
||||
if perms.allow != nil {
|
||||
np, _ := perms.allow.NumInterest(subject)
|
||||
allowed = np != 0
|
||||
allowed = perms.allow.HasInterest(subject)
|
||||
}
|
||||
// If we have a deny list and are currently allowed, check that as well.
|
||||
if allowed && perms.deny != nil {
|
||||
np, _ := perms.deny.NumInterest(subject)
|
||||
allowed = np == 0
|
||||
allowed = !perms.deny.HasInterest(subject)
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
@@ -5729,6 +5716,12 @@ func mqttToNATSSubjectConversion(mt []byte, wcOk bool) ([]byte, error) {
|
||||
case ' ':
|
||||
// As of now, we cannot support ' ' in the MQTT topic/filter.
|
||||
return nil, errMQTTUnsupportedCharacters
|
||||
case 0x7f:
|
||||
// SubjectTree uses DEL as an internal pivot marker, so retained
|
||||
// subjects containing it cannot be indexed safely, including
|
||||
// legacy retained messages recovered from the retained-message
|
||||
// stream.
|
||||
return nil, errMQTTUnsupportedCharacters
|
||||
case btsep:
|
||||
if !cp {
|
||||
makeCopy(i)
|
||||
|
||||
+43
-90
@@ -1,4 +1,4 @@
|
||||
// Copyright 2024-2025 The NATS Authors
|
||||
// Copyright 2024-2026 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
const (
|
||||
MsgTraceDest = "Nats-Trace-Dest"
|
||||
MsgTraceDestDisabled = "trace disabled" // This must be an invalid NATS subject
|
||||
MsgTraceHop = "Nats-Trace-Hop"
|
||||
MsgTraceOriginAccount = "Nats-Trace-Origin-Account"
|
||||
MsgTraceOnly = "Nats-Trace-Only"
|
||||
@@ -33,10 +34,19 @@ const (
|
||||
// External trace header. Note that this header is normally in lower
|
||||
// case (https://www.w3.org/TR/trace-context/#header-name). Vendors
|
||||
// MUST expect the header in any case (upper, lower, mixed), and
|
||||
// SHOULD send the header name in lowercase.
|
||||
// SHOULD send the header name in lowercase. We used to change it
|
||||
// to lower case, but no longer do that in 2.14.
|
||||
traceParentHdr = "traceparent"
|
||||
)
|
||||
|
||||
var (
|
||||
traceDestHdrAsBytes = stringToBytes(MsgTraceDest)
|
||||
traceDestDisabledAsBytes = stringToBytes(MsgTraceDestDisabled)
|
||||
traceParentHdrAsBytes = stringToBytes(traceParentHdr)
|
||||
crLFAsBytes = stringToBytes(CR_LF)
|
||||
dashAsBytes = stringToBytes("-")
|
||||
)
|
||||
|
||||
type MsgTraceType string
|
||||
|
||||
// Type of message trace events in the MsgTraceEvents list.
|
||||
@@ -352,7 +362,6 @@ func (c *client) initMsgTrace() *msgTrace {
|
||||
}
|
||||
return vv[0]
|
||||
}
|
||||
ct := getCompressionType(getHdrVal(acceptEncodingHeader))
|
||||
var (
|
||||
dest string
|
||||
traceOnly bool
|
||||
@@ -454,9 +463,9 @@ func (c *client) initMsgTrace() *msgTrace {
|
||||
}
|
||||
// Check sampling, but only from origin server.
|
||||
if c.kind == CLIENT && !sample(sampling) {
|
||||
// Need to desactivate the traceParentHdr so that if the message
|
||||
// is routed, it does possibly trigger a trace there.
|
||||
disableTraceHeaders(c, hdr)
|
||||
// Need to disable tracing so that if the message is routed, it won't
|
||||
// trigger a trace there.
|
||||
c.msgBuf = c.setHeader(MsgTraceDest, MsgTraceDestDisabled, c.msgBuf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -465,7 +474,7 @@ func (c *client) initMsgTrace() *msgTrace {
|
||||
acc: acc,
|
||||
oan: oan,
|
||||
dest: dest,
|
||||
ct: ct,
|
||||
ct: getCompressionType(getHdrVal(acceptEncodingHeader)),
|
||||
hop: hop,
|
||||
event: &MsgTraceEvent{
|
||||
Request: MsgTraceRequest{
|
||||
@@ -503,9 +512,7 @@ func sample(sampling int) bool {
|
||||
// the headers have been lifted due to the presence of the external trace header
|
||||
// only.
|
||||
// Note that because of the traceParentHdr, the search is done in a case
|
||||
// insensitive way, but if the header is found, it is rewritten in lower case
|
||||
// as suggested by the spec, but also to make it easier to disable the header
|
||||
// when needed.
|
||||
// insensitive way. We used to rewrite it in lower case but no longer do since v2.14.
|
||||
func genHeaderMapIfTraceHeadersPresent(hdr []byte) (map[string][]string, bool) {
|
||||
|
||||
var (
|
||||
@@ -520,11 +527,6 @@ func genHeaderMapIfTraceHeadersPresent(hdr []byte) (map[string][]string, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
traceDestHdrAsBytes := stringToBytes(MsgTraceDest)
|
||||
traceParentHdrAsBytes := stringToBytes(traceParentHdr)
|
||||
crLFAsBytes := stringToBytes(CR_LF)
|
||||
dashAsBytes := stringToBytes("-")
|
||||
|
||||
keys := _keys[:0]
|
||||
vals := _vals[:0]
|
||||
|
||||
@@ -537,46 +539,50 @@ func genHeaderMapIfTraceHeadersPresent(hdr []byte) (map[string][]string, bool) {
|
||||
keyStart := i
|
||||
key := hdr[keyStart : keyStart+del]
|
||||
i += del + 1
|
||||
for i < len(hdr) && (hdr[i] == ' ' || hdr[i] == '\t') {
|
||||
i++
|
||||
}
|
||||
valStart := i
|
||||
nl := bytes.Index(hdr[valStart:], crLFAsBytes)
|
||||
if nl < 0 {
|
||||
break
|
||||
}
|
||||
if len(key) > 0 {
|
||||
val := bytes.Trim(hdr[valStart:valStart+nl], " \t")
|
||||
valEnd := valStart + nl
|
||||
for valEnd > valStart && (hdr[valEnd-1] == ' ' || hdr[valEnd-1] == '\t') {
|
||||
valEnd--
|
||||
}
|
||||
val := hdr[valStart:valEnd]
|
||||
if len(key) > 0 && len(val) > 0 {
|
||||
vals = append(vals, val)
|
||||
|
||||
// We search for our special keys only if not already found.
|
||||
|
||||
// Check for the external trace header.
|
||||
if bytes.EqualFold(key, traceParentHdrAsBytes) {
|
||||
// Rewrite the header using lower case if needed.
|
||||
if !bytes.Equal(key, traceParentHdrAsBytes) {
|
||||
copy(hdr[keyStart:], traceParentHdrAsBytes)
|
||||
}
|
||||
// Search needs to be case insensitive.
|
||||
if !traceParentHdrFound && bytes.EqualFold(key, traceParentHdrAsBytes) {
|
||||
// We will now check if the value has sampling or not.
|
||||
// TODO(ik): Not sure if this header can have multiple values
|
||||
// or not, and if so, what would be the rule to check for
|
||||
// sampling. What is done here is to check them all until we
|
||||
// found one with sampling.
|
||||
if !traceParentHdrFound {
|
||||
tk := bytes.Split(val, dashAsBytes)
|
||||
if len(tk) == 4 && len([]byte(tk[3])) == 2 {
|
||||
if hexVal, err := strconv.ParseInt(bytesToString(tk[3]), 16, 8); err == nil {
|
||||
if hexVal&0x1 == 0x1 {
|
||||
traceParentHdrFound = true
|
||||
}
|
||||
tk := bytes.Split(val, dashAsBytes)
|
||||
if len(tk) == 4 && len([]byte(tk[3])) == 2 {
|
||||
if hexVal, err := strconv.ParseInt(bytesToString(tk[3]), 16, 8); err == nil {
|
||||
if hexVal&0x1 == 0x1 {
|
||||
traceParentHdrFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add to the keys with the external trace header in lower case.
|
||||
keys = append(keys, traceParentHdrAsBytes)
|
||||
} else {
|
||||
// Is the key the Nats-Trace-Dest header?
|
||||
if bytes.EqualFold(key, traceDestHdrAsBytes) {
|
||||
traceDestHdrFound = true
|
||||
} else if !traceDestHdrFound && bytes.Equal(key, traceDestHdrAsBytes) {
|
||||
// This is the Nats-Trace-Dest header, check the value to see
|
||||
// if it indicates that the trace was disabled.
|
||||
if bytes.Equal(val, traceDestDisabledAsBytes) {
|
||||
return nil, false
|
||||
}
|
||||
// Add to the keys and preserve the key's case
|
||||
keys = append(keys, key)
|
||||
traceDestHdrFound = true
|
||||
}
|
||||
// Add to the keys and preserve the key's case
|
||||
keys = append(keys, key)
|
||||
}
|
||||
i += nl + 2
|
||||
}
|
||||
@@ -655,59 +661,6 @@ func (t *msgTrace) setHopHeader(c *client, msg []byte) []byte {
|
||||
return c.setHeader(MsgTraceHop, t.nhop, msg)
|
||||
}
|
||||
|
||||
// Will look for the MsgTraceSendTo and traceParentHdr headers and change the first
|
||||
// character to an 'X' so that if this message is sent to a remote, the remote
|
||||
// will not initialize tracing since it won't find the actual trace headers.
|
||||
// The function returns the position of the headers so it can efficiently be
|
||||
// re-enabled by calling enableTraceHeaders.
|
||||
// Note that if `msg` can be either the header alone or the full message
|
||||
// (header and payload). This function will use c.pa.hdr to limit the
|
||||
// search to the header section alone.
|
||||
func disableTraceHeaders(c *client, msg []byte) []int {
|
||||
// Code largely copied from getHeader(), except that we don't need the value
|
||||
if c.pa.hdr <= 0 {
|
||||
return []int{-1, -1}
|
||||
}
|
||||
hdr := msg[:c.pa.hdr]
|
||||
headers := [2]string{MsgTraceDest, traceParentHdr}
|
||||
positions := [2]int{-1, -1}
|
||||
for i := 0; i < 2; i++ {
|
||||
key := stringToBytes(headers[i])
|
||||
pos := bytes.Index(hdr, key)
|
||||
if pos < 0 {
|
||||
continue
|
||||
}
|
||||
// Make sure this key does not have additional prefix.
|
||||
if pos < 2 || hdr[pos-1] != '\n' || hdr[pos-2] != '\r' {
|
||||
continue
|
||||
}
|
||||
index := pos + len(key)
|
||||
if index >= len(hdr) {
|
||||
continue
|
||||
}
|
||||
if hdr[index] != ':' {
|
||||
continue
|
||||
}
|
||||
// Disable the trace by altering the first character of the header
|
||||
hdr[pos] = 'X'
|
||||
positions[i] = pos
|
||||
}
|
||||
// Return the positions of those characters so we can re-enable the headers.
|
||||
return positions[:2]
|
||||
}
|
||||
|
||||
// Changes back the character at the given position `pos` in the `msg`
|
||||
// byte slice to the first character of the MsgTraceSendTo header.
|
||||
func enableTraceHeaders(msg []byte, positions []int) {
|
||||
firstChar := [2]byte{MsgTraceDest[0], traceParentHdr[0]}
|
||||
for i, pos := range positions {
|
||||
if pos == -1 {
|
||||
continue
|
||||
}
|
||||
msg[pos] = firstChar[i]
|
||||
}
|
||||
}
|
||||
|
||||
func (t *msgTrace) setIngressError(err string) {
|
||||
if i := t.event.Ingress(); i != nil {
|
||||
i.Error = err
|
||||
|
||||
+165
-7
@@ -27,6 +27,7 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -108,6 +109,34 @@ type CompressionOpts struct {
|
||||
RTTThresholds []time.Duration
|
||||
}
|
||||
|
||||
func (c1 *CompressionOpts) equals(c2 *CompressionOpts) bool {
|
||||
if c1 == c2 {
|
||||
return true
|
||||
}
|
||||
if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) {
|
||||
return false
|
||||
}
|
||||
if c1.Mode != c2.Mode {
|
||||
return false
|
||||
}
|
||||
// For s2_auto, if one has an empty RTTThresholds, it is equivalent
|
||||
// to the defaultCompressionS2AutoRTTThresholds array, so compare with that.
|
||||
if c1.Mode == CompressionS2Auto {
|
||||
rtts1 := c1.RTTThresholds
|
||||
if len(rtts1) == 0 {
|
||||
rtts1 = defaultCompressionS2AutoRTTThresholds
|
||||
}
|
||||
rtts2 := c2.RTTThresholds
|
||||
if len(rtts2) == 0 {
|
||||
rtts2 = defaultCompressionS2AutoRTTThresholds
|
||||
}
|
||||
if !reflect.DeepEqual(rtts1, rtts2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GatewayOpts are options for gateways.
|
||||
// NOTE: This structure is no longer used for monitoring endpoints
|
||||
// and json tags are deprecated and may be removed in the future.
|
||||
@@ -283,6 +312,48 @@ type RemoteLeafOpts struct {
|
||||
// existing connection will be closed and not solicited again (until it is changed
|
||||
// to `false` again.
|
||||
Disabled bool `json:"-"`
|
||||
|
||||
// If this is set to true, this remote will ignore any server leafnode URLs
|
||||
// returned by the hub, allowing the user to fully manage the servers this
|
||||
// remote can connect to.
|
||||
IgnoreDiscoveredServers bool `json:"-"`
|
||||
}
|
||||
|
||||
// Returns a string representation of this `RemoteLeafOpts` object, containing
|
||||
// the URLs (unredacted), the account (or "$G" if none is specified) and, if present,
|
||||
// the credentials filename.
|
||||
func (r *RemoteLeafOpts) name() string {
|
||||
return generateRemoteLeafOptsName(r, false)
|
||||
}
|
||||
|
||||
// Same than RemoteLeafOpts.name() but uses redacted URLs. This is to be used for logging.
|
||||
func (r *RemoteLeafOpts) safeName() string {
|
||||
return generateRemoteLeafOptsName(r, true)
|
||||
}
|
||||
|
||||
func generateRemoteLeafOptsName(r *RemoteLeafOpts, redacted bool) string {
|
||||
acc := r.LocalAccount
|
||||
if acc == _EMPTY_ {
|
||||
acc = globalAccountName
|
||||
}
|
||||
var optional string
|
||||
// There could be Credentials or NKey, not both (would be caught as a misconfig)
|
||||
if c := r.Credentials; c != _EMPTY_ {
|
||||
optional = fmt.Sprintf(", credentials=%q", c)
|
||||
} else if nk := r.Nkey; nk != _EMPTY_ {
|
||||
if redacted {
|
||||
optional = ", nkey=\"[REDACTED]\""
|
||||
} else {
|
||||
optional = fmt.Sprintf(", nkey=%q", nk)
|
||||
}
|
||||
}
|
||||
var urls []*url.URL
|
||||
if redacted {
|
||||
urls = redactURLList(r.URLs)
|
||||
} else {
|
||||
urls = r.URLs
|
||||
}
|
||||
return fmt.Sprintf("urls=%q, account=%q%s", urls, acc, optional)
|
||||
}
|
||||
|
||||
// JSLimitOpts are active limits for the meta cluster
|
||||
@@ -387,6 +458,7 @@ type Options struct {
|
||||
JetStreamTpm JSTpmOpts
|
||||
JetStreamMaxCatchup int64
|
||||
JetStreamRequestQueueLimit int64
|
||||
JetStreamInfoQueueLimit int64
|
||||
JetStreamMetaCompact uint64
|
||||
JetStreamMetaCompactSize uint64
|
||||
JetStreamMetaCompactSync bool
|
||||
@@ -478,6 +550,9 @@ type Options struct {
|
||||
// Metadata describing the server. They will be included in 'Z' responses.
|
||||
Metadata map[string]string `json:"-"`
|
||||
|
||||
// FeatureFlags the server opts-in to (or opts-out of). They will be included in 'Z' responses.
|
||||
FeatureFlags map[string]bool `json:"-"`
|
||||
|
||||
// OCSPConfig enables OCSP Stapling in the server.
|
||||
OCSPConfig *OCSPConfig
|
||||
tlsConfigOpts *TLSConfigOpts
|
||||
@@ -1748,6 +1823,29 @@ func (o *Options) processConfigFileLine(k string, v any, errors *[]error, warnin
|
||||
*errors = append(*errors, err)
|
||||
return
|
||||
}
|
||||
case "feature_flags":
|
||||
var err error
|
||||
switch v := v.(type) {
|
||||
case map[string]any:
|
||||
for mk, mv := range v {
|
||||
tk, mv = unwrapValue(mv, <)
|
||||
b, ok := mv.(bool)
|
||||
if !ok {
|
||||
err = &configErr{tk, fmt.Sprintf("error parsing feature flag %q: expected bool, got %T", mk, mv)}
|
||||
break
|
||||
}
|
||||
if o.FeatureFlags == nil {
|
||||
o.FeatureFlags = make(map[string]bool)
|
||||
}
|
||||
o.FeatureFlags[mk] = b
|
||||
}
|
||||
default:
|
||||
err = &configErr{tk, fmt.Sprintf("error parsing feature flags: unsupported type %T", v)}
|
||||
}
|
||||
if err != nil {
|
||||
*errors = append(*errors, err)
|
||||
return
|
||||
}
|
||||
case "default_js_domain":
|
||||
vv, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
@@ -2641,6 +2739,12 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er
|
||||
return &configErr{tk, fmt.Sprintf("Expected a parseable size for %q, got %v", mk, mv)}
|
||||
}
|
||||
opts.JetStreamRequestQueueLimit = lim
|
||||
case "info_queue_limit":
|
||||
lim, ok := mv.(int64)
|
||||
if !ok {
|
||||
return &configErr{tk, fmt.Sprintf("Expected a parseable size for %q, got %v", mk, mv)}
|
||||
}
|
||||
opts.JetStreamInfoQueueLimit = lim
|
||||
case "meta_compact":
|
||||
thres, ok := mv.(int64)
|
||||
if !ok || thres < 0 {
|
||||
@@ -2936,6 +3040,7 @@ func parseRemoteLeafNodes(v any, errors *[]error, warnings *[]error) ([]*RemoteL
|
||||
if !ok {
|
||||
return nil, &configErr{tk, fmt.Sprintf("Expected remotes field to be an array, got %T", v)}
|
||||
}
|
||||
names := make(map[string]struct{})
|
||||
remotes := make([]*RemoteLeafOpts, 0, len(ra))
|
||||
for _, r := range ra {
|
||||
tk, r = unwrapValue(r, <)
|
||||
@@ -3105,6 +3210,8 @@ func parseRemoteLeafNodes(v any, errors *[]error, warnings *[]error) ([]*RemoteL
|
||||
}
|
||||
}
|
||||
}
|
||||
case "ignore_discovered_servers":
|
||||
remote.IgnoreDiscoveredServers = v.(bool)
|
||||
default:
|
||||
if !tk.IsUsedVariable() {
|
||||
err := &unknownConfigFieldErr{
|
||||
@@ -3128,6 +3235,12 @@ func parseRemoteLeafNodes(v any, errors *[]error, warnings *[]error) ([]*RemoteL
|
||||
*warnings = append(*warnings, &configErr{proxyToken, warn})
|
||||
}
|
||||
}
|
||||
rn := remote.name()
|
||||
if _, dup := names[rn]; dup {
|
||||
*errors = append(*errors, &configErr{tk, fmt.Sprintf("duplicate remote %s", remote.safeName())})
|
||||
continue
|
||||
}
|
||||
names[rn] = struct{}{}
|
||||
remotes = append(remotes, remote)
|
||||
}
|
||||
return remotes, nil
|
||||
@@ -6006,6 +6119,9 @@ func setBaselineOptions(opts *Options) {
|
||||
if opts.JetStreamRequestQueueLimit <= 0 {
|
||||
opts.JetStreamRequestQueueLimit = JSDefaultRequestQueueLimit
|
||||
}
|
||||
if opts.JetStreamInfoQueueLimit <= 0 {
|
||||
opts.JetStreamInfoQueueLimit = opts.JetStreamRequestQueueLimit
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultAuthTimeout(tls *tls.Config, tlsTimeout float64) float64 {
|
||||
@@ -6439,14 +6555,56 @@ func expandPath(p string) (string, error) {
|
||||
// RedactArgs redacts sensitive arguments from the command line.
|
||||
// For example, turns '--pass=secret' into '--pass=[REDACTED]'.
|
||||
func RedactArgs(args []string) {
|
||||
secret := regexp.MustCompile("^-{1,2}(user|pass|auth)(=.*)?$")
|
||||
secretArg := regexp.MustCompile("^-{1,2}(user|pass|auth)(=.*)?$")
|
||||
routeURLArg := regexp.MustCompile("^-{1,2}(routes)(=.*)?$")
|
||||
singleURLArg := regexp.MustCompile("^-{1,2}(cluster|cluster_listen)(=.*)?$")
|
||||
for i, arg := range args {
|
||||
if secret.MatchString(arg) {
|
||||
if idx := strings.Index(arg, "="); idx != -1 {
|
||||
args[i] = arg[:idx] + "=[REDACTED]"
|
||||
} else if i+1 < len(args) {
|
||||
args[i+1] = "[REDACTED]"
|
||||
}
|
||||
switch {
|
||||
case secretArg.MatchString(arg):
|
||||
redactArgValue(args, i, func(_ string) string { return "[REDACTED]" })
|
||||
case routeURLArg.MatchString(arg):
|
||||
redactArgValue(args, i, redactURLListUser)
|
||||
case singleURLArg.MatchString(arg):
|
||||
redactArgValue(args, i, redactURLUser)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func redactArgValue(args []string, i int, redact func(string) string) {
|
||||
if flag, value, ok := strings.Cut(args[i], "="); ok {
|
||||
args[i] = flag + "=" + redact(value)
|
||||
} else if i+1 < len(args) {
|
||||
args[i+1] = redact(args[i+1])
|
||||
}
|
||||
}
|
||||
|
||||
func redactURLUser(raw string) string {
|
||||
if !strings.Contains(raw, "@") {
|
||||
return raw
|
||||
}
|
||||
parseValue := strings.TrimSpace(raw)
|
||||
restoreRandom := false
|
||||
if prefix, ok := strings.CutSuffix(parseValue, ":-1"); ok {
|
||||
parseValue = prefix + ":0"
|
||||
restoreRandom = true
|
||||
}
|
||||
u, err := url.Parse(parseValue)
|
||||
if err != nil || u.User == nil {
|
||||
return raw
|
||||
}
|
||||
// url.String escapes brackets in userinfo, so use
|
||||
// a placeholder here and rewrite it afterward.
|
||||
u.User = url.User("_REDACTED_")
|
||||
if restoreRandom {
|
||||
u.Host = strings.TrimSuffix(u.Host, ":0") + ":-1"
|
||||
}
|
||||
return strings.Replace(u.String(), "_REDACTED_@", "[REDACTED]@", 1)
|
||||
}
|
||||
|
||||
func redactURLListUser(raw string) string {
|
||||
parts := strings.Split(raw, ",")
|
||||
for i, part := range parts {
|
||||
parts[i] = redactURLUser(part)
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
+37
-22
@@ -166,32 +166,40 @@ func (c *client) parse(buf []byte) error {
|
||||
goto authErr
|
||||
}
|
||||
var ok bool
|
||||
// Check here for NoAuthUser. If this is set allow non CONNECT protos as our first.
|
||||
// E.g. telnet proto demos.
|
||||
if noAuthUser := s.getOpts().NoAuthUser; noAuthUser != _EMPTY_ {
|
||||
s.mu.Lock()
|
||||
user, exists := s.users[noAuthUser]
|
||||
s.mu.Unlock()
|
||||
if exists {
|
||||
c.RegisterUser(user)
|
||||
c.mu.Lock()
|
||||
c.clearAuthTimer()
|
||||
c.flags.set(connectReceived)
|
||||
c.mu.Unlock()
|
||||
authSet, ok = false, true
|
||||
switch c.kind {
|
||||
case CLIENT:
|
||||
// Check here for NoAuthUser. If this is set allow non CONNECT protos as our first.
|
||||
// E.g. telnet proto demos.
|
||||
opts := s.getOpts()
|
||||
noAuthUser := opts.NoAuthUser
|
||||
if c.ws != nil {
|
||||
if noAuthUserWS := opts.Websocket.NoAuthUser; noAuthUserWS != _EMPTY_ {
|
||||
noAuthUser = noAuthUserWS
|
||||
}
|
||||
}
|
||||
if noAuthUser != _EMPTY_ {
|
||||
s.mu.Lock()
|
||||
user, exists := s.users[noAuthUser]
|
||||
s.mu.Unlock()
|
||||
if exists {
|
||||
c.RegisterUser(user)
|
||||
c.mu.Lock()
|
||||
c.clearAuthTimer()
|
||||
c.flags.set(connectReceived)
|
||||
c.mu.Unlock()
|
||||
authSet, ok = false, true
|
||||
}
|
||||
}
|
||||
case LEAF:
|
||||
// Compressed inbound leaf-node negotiation may require INFO
|
||||
// before CONNECT. Without compression, leaf connections must
|
||||
// still start with CONNECT.
|
||||
ok = (b == 'I' || b == 'i') && needsCompression(s.getOpts().LeafNode.Compression.Mode)
|
||||
}
|
||||
if !ok {
|
||||
goto authErr
|
||||
}
|
||||
}
|
||||
// If the connection is a gateway connection, make sure that
|
||||
// if this is an inbound, it starts with a CONNECT.
|
||||
if c.kind == GATEWAY && !c.gw.outbound && !c.gw.connected {
|
||||
// Use auth violation since no CONNECT was sent.
|
||||
// It could be a parseErr too.
|
||||
goto authErr
|
||||
}
|
||||
}
|
||||
switch b {
|
||||
case 'P', 'p':
|
||||
@@ -1250,10 +1258,17 @@ func protoSnippet(start, max int, buf []byte) string {
|
||||
// If so, an error is sent to the client and the connection is closed.
|
||||
// The error ErrMaxControlLine is returned.
|
||||
func (c *client) overMaxControlLineLimit(arg []byte, mcl int32) error {
|
||||
// Widen to int64 so mcl*16 cannot overflow for large configured values.
|
||||
effective := int64(mcl)
|
||||
if c.kind != CLIENT {
|
||||
return nil
|
||||
// This is the upper bound on argBuf length for LEAF, ROUTER, and GATEWAY connections.
|
||||
// These kinds need longer arg lines than CLIENT (which is capped at mcl=4096 by default)
|
||||
// because cluster/leaf frames encode origin, account, reply, and queue groups.
|
||||
// By default, this is 64 KB, which matches maxBufSize so a single oversized read
|
||||
// is caught on the very next parse call.
|
||||
effective *= 16
|
||||
}
|
||||
if len(arg) > int(mcl) {
|
||||
if int64(len(arg)) > effective {
|
||||
err := NewErrorCtx(ErrMaxControlLine, "State %d, max_control_line %d, Buffer len %d (snip: %s...)",
|
||||
c.state, int(mcl), len(c.argBuf), protoSnippet(0, MAX_CONTROL_LINE_SNIPPET_SIZE, arg))
|
||||
c.sendErr(err.Error())
|
||||
|
||||
+198
-55
@@ -89,6 +89,7 @@ type RaftNode interface {
|
||||
RecreateInternalSubs() error
|
||||
IsSystemAccount() bool
|
||||
GetTrafficAccountName() string
|
||||
GetWriteErr() error
|
||||
}
|
||||
|
||||
// RaftNodeCheckpoint is used as an alternative to a direct InstallSnapshot.
|
||||
@@ -248,6 +249,9 @@ type raft struct {
|
||||
scaleUp bool // The node is part of a scale up, puts us in observer mode until the log contains data.
|
||||
deleted bool // If the node was deleted.
|
||||
snapshotting bool // Snapshot is in progress.
|
||||
quorumPaused bool // Pause replication and quorum participation to prevent log growth during slow applies.
|
||||
|
||||
overrunCount uint64 // Counter of how many times we were overrun, either as follower or as leader.
|
||||
}
|
||||
|
||||
type proposedEntry struct {
|
||||
@@ -487,9 +491,12 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel
|
||||
n.papplied = 0
|
||||
if _, ok := n.wal.(*memStore); ok {
|
||||
_ = os.RemoveAll(filepath.Join(n.sd, snapshotsDir))
|
||||
} else {
|
||||
// See if we have any snapshots and if so load and process on startup.
|
||||
n.setupLastSnapshot()
|
||||
} else if err := n.setupLastSnapshot(); err != nil && err != errNoSnapAvailable {
|
||||
// If we failed to recover from the snapshot, then we should surface
|
||||
// the error upwards, otherwise we can complete recovery but have only
|
||||
// a partial view of the world.
|
||||
n.shutdown()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We may have restored the peer state from the
|
||||
@@ -503,6 +510,7 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel
|
||||
|
||||
// Make sure that the snapshots directory exists.
|
||||
if err := os.MkdirAll(filepath.Join(n.sd, snapshotsDir), defaultDirPerms); err != nil {
|
||||
n.shutdown()
|
||||
return nil, fmt.Errorf("could not create snapshots directory - %v", err)
|
||||
}
|
||||
|
||||
@@ -521,12 +529,6 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel
|
||||
|
||||
if state.Msgs > 0 {
|
||||
n.debug("Replaying state of %d entries", state.Msgs)
|
||||
if first, err := n.loadFirstEntry(); err == nil {
|
||||
n.pterm, n.pindex = first.pterm, first.pindex
|
||||
if first.commit > 0 && first.commit > n.commit {
|
||||
n.commit = first.commit
|
||||
}
|
||||
}
|
||||
|
||||
// This process will queue up entries on our applied queue but prior to the upper
|
||||
// state machine running. So we will monitor how much we have queued and if we
|
||||
@@ -537,6 +539,25 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel
|
||||
// yet. Replay them.
|
||||
for index, qsz := state.FirstSeq, 0; index <= state.LastSeq; index++ {
|
||||
ae, err := n.loadEntry(index)
|
||||
// The first entry in our WAL initializes state but must align with our snapshot if we had one.
|
||||
// Importantly, check this first, as we might need to truncate the WAL further than the index.
|
||||
if index == state.FirstSeq {
|
||||
// If the entry is missing, corrupt, or doesn't align with the snapshot, truncate the WAL.
|
||||
if err != nil || ae == nil || ae.pindex != index-1 || n.pindex != ae.pindex {
|
||||
if err != nil {
|
||||
n.warn("Could not load %d from WAL [%+v]: %v", index, state, err)
|
||||
} else {
|
||||
n.warn("Misaligned WAL, will truncate")
|
||||
}
|
||||
// Truncate to the snapshot or beginning if there is none.
|
||||
truncateAndErr(n.pindex)
|
||||
break
|
||||
}
|
||||
n.pterm, n.pindex = ae.pterm, ae.pindex
|
||||
if ae.commit > 0 && ae.commit > n.commit {
|
||||
n.commit = ae.commit
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
n.warn("Could not load %d from WAL [%+v]: %v", index, state, err)
|
||||
// Truncate to the previous correct entry.
|
||||
@@ -902,6 +923,16 @@ func (n *raft) Propose(data []byte) error {
|
||||
if werr := n.werr; werr != nil {
|
||||
return werr
|
||||
}
|
||||
|
||||
if n.isLeaderOverrun() {
|
||||
var state StreamState
|
||||
n.wal.FastState(&state)
|
||||
n.warn("Leader falling behind, stepping down: pindex %d, commit %d, applied %d, WAL size %s", n.pindex, n.commit, n.applied, friendlyBytes(state.Bytes))
|
||||
// Stepdown without leader transfer, likely all replicas will be overrun, and we need time to recover.
|
||||
n.stepdownLocked(noLeader)
|
||||
n.overrunCount++
|
||||
return errNotLeader
|
||||
}
|
||||
n.prop.push(newProposedEntry(newEntry(EntryNormal, data), _EMPTY_))
|
||||
return nil
|
||||
}
|
||||
@@ -921,12 +952,39 @@ func (n *raft) ProposeMulti(entries []*Entry) error {
|
||||
if werr := n.werr; werr != nil {
|
||||
return werr
|
||||
}
|
||||
|
||||
if n.isLeaderOverrun() {
|
||||
var state StreamState
|
||||
n.wal.FastState(&state)
|
||||
n.warn("Leader falling behind, stepping down: pindex %d, commit %d, applied %d, WAL size %s", n.pindex, n.commit, n.applied, friendlyBytes(state.Bytes))
|
||||
// Stepdown without leader transfer, likely all replicas will be overrun, and we need time to recover.
|
||||
n.stepdownLocked(noLeader)
|
||||
n.overrunCount++
|
||||
return errNotLeader
|
||||
}
|
||||
for _, e := range entries {
|
||||
n.prop.push(newProposedEntry(e, _EMPTY_))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLeaderOverrun returns whether we are overrun and should step down due to continuously increasing
|
||||
// uncommitted or unapplied entries. If triggered, this means we're being severely overrun by
|
||||
// incoming proposals or the system is degraded such that it's too slow (or unable) to process them.
|
||||
// Stepping down means the system gets to "breathe" for a bit, until a new leader can be elected.
|
||||
// Lock should be held.
|
||||
func (n *raft) isLeaderOverrun() bool {
|
||||
applied := max(n.applied, n.papplied)
|
||||
commit := max(n.commit, n.papplied)
|
||||
// We only do this past a high threshold to protect ourselves.
|
||||
// Worst-case we'll have 2x the threshold, once in uncommitted and once in unapplied entries.
|
||||
// Either the number of uncommitted entries is over the threshold: we're not getting quorum from our followers.
|
||||
uncommittedThreshold := n.pindex > commit && n.pindex-commit > pauseQuorumThreshold
|
||||
// Or, the number of in-memory committed but not yet applied entries is over the threshold: we're slow to apply.
|
||||
unappliedThreshold := commit > applied && commit-applied > pauseQuorumThreshold
|
||||
return uncommittedThreshold || unappliedThreshold
|
||||
}
|
||||
|
||||
// ForwardProposal will forward the proposal to the leader if known.
|
||||
// If we are the leader this is the same as calling propose.
|
||||
func (n *raft) ForwardProposal(entry []byte) error {
|
||||
@@ -1075,8 +1133,8 @@ func (n *raft) PauseApply() error {
|
||||
}
|
||||
|
||||
func (n *raft) pauseApplyLocked() {
|
||||
// If we are currently a candidate make sure we step down.
|
||||
if n.State() == Candidate {
|
||||
// If we are currently not a follower, make sure we step down.
|
||||
if n.State() != Follower {
|
||||
n.stepdownLocked(noLeader)
|
||||
}
|
||||
|
||||
@@ -1558,11 +1616,14 @@ func termAndIndexFromSnapFile(sn string) (term, index uint64, err error) {
|
||||
// setupLastSnapshot is called at startup to try and recover the last snapshot from
|
||||
// the disk if possible. We will try to recover the term, index and commit/applied
|
||||
// indices and then notify the upper layer what we found. Compacts the WAL if needed.
|
||||
func (n *raft) setupLastSnapshot() {
|
||||
func (n *raft) setupLastSnapshot() error {
|
||||
snapDir := filepath.Join(n.sd, snapshotsDir)
|
||||
psnaps, err := os.ReadDir(snapDir)
|
||||
if err != nil {
|
||||
return
|
||||
if os.IsNotExist(err) {
|
||||
return errNoSnapAvailable
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var lterm, lindex uint64
|
||||
@@ -1586,18 +1647,8 @@ func (n *raft) setupLastSnapshot() {
|
||||
os.Remove(sfile)
|
||||
}
|
||||
}
|
||||
|
||||
// Now cleanup any old entries
|
||||
for _, sf := range psnaps {
|
||||
sfile := filepath.Join(snapDir, sf.Name())
|
||||
if sfile != latest {
|
||||
n.debug("Removing old snapshot: %q", sfile)
|
||||
os.Remove(sfile)
|
||||
}
|
||||
}
|
||||
|
||||
if latest == _EMPTY_ {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set latest snapshot we have.
|
||||
@@ -1607,13 +1658,7 @@ func (n *raft) setupLastSnapshot() {
|
||||
n.snapfile = latest
|
||||
snap, err := n.loadLastSnapshot()
|
||||
if err != nil {
|
||||
// We failed to recover the last snapshot for some reason, so we will
|
||||
// assume it has been corrupted and will try to delete it.
|
||||
if n.snapfile != _EMPTY_ {
|
||||
os.Remove(n.snapfile)
|
||||
n.snapfile = _EMPTY_
|
||||
}
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
// We successfully recovered the last snapshot from the disk.
|
||||
@@ -1627,8 +1672,8 @@ func (n *raft) setupLastSnapshot() {
|
||||
n.papplied = snap.lastIndex
|
||||
// Restore the peerState
|
||||
ps, err := decodePeerState(snap.peerstate)
|
||||
if err == nil {
|
||||
n.processPeerState(ps)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.processPeerState(ps)
|
||||
n.extSt = ps.domainExt
|
||||
@@ -1636,7 +1681,19 @@ func (n *raft) setupLastSnapshot() {
|
||||
n.apply.push(newCommittedEntry(n.commit, []*Entry{{EntrySnapshot, snap.data}}))
|
||||
if _, err := n.wal.Compact(snap.lastIndex + 1); err != nil {
|
||||
n.setWriteErrLocked(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Now cleanup any old entries. We only do this once we know that the
|
||||
// latest snapshot was OK.
|
||||
for _, sf := range psnaps {
|
||||
if sfile := filepath.Join(snapDir, sf.Name()); sfile != latest {
|
||||
n.debug("Removing old snapshot: %q", sfile)
|
||||
os.Remove(sfile)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadLastSnapshot will load and return our last snapshot.
|
||||
@@ -1652,14 +1709,10 @@ func (n *raft) loadLastSnapshot() (*snapshot, error) {
|
||||
|
||||
if err != nil {
|
||||
n.warn("Error reading snapshot: %v", err)
|
||||
os.Remove(n.snapfile)
|
||||
n.snapfile = _EMPTY_
|
||||
return nil, err
|
||||
}
|
||||
if len(buf) < minSnapshotLen {
|
||||
n.warn("Snapshot corrupt, too short")
|
||||
os.Remove(n.snapfile)
|
||||
n.snapfile = _EMPTY_
|
||||
return nil, errSnapshotCorrupt
|
||||
}
|
||||
|
||||
@@ -1671,8 +1724,6 @@ func (n *raft) loadLastSnapshot() (*snapshot, error) {
|
||||
var hb [highwayhash.Size64]byte
|
||||
if !bytes.Equal(lchk[:], n.hh.Sum(hb[:0])) {
|
||||
n.warn("Snapshot corrupt, checksums did not match")
|
||||
os.Remove(n.snapfile)
|
||||
n.snapfile = _EMPTY_
|
||||
return nil, errSnapshotCorrupt
|
||||
}
|
||||
|
||||
@@ -1686,12 +1737,12 @@ func (n *raft) loadLastSnapshot() (*snapshot, error) {
|
||||
}
|
||||
|
||||
// We had a bug in 2.9.12 that would allow snapshots on last index of 0.
|
||||
// Detect that here and return err.
|
||||
// Detect that and continue anyway, nothing else we can do about it.
|
||||
if snap.lastIndex == 0 {
|
||||
n.warn("Snapshot with last index 0 is invalid, cleaning up")
|
||||
os.Remove(n.snapfile)
|
||||
n.snapfile = _EMPTY_
|
||||
return nil, errSnapshotCorrupt
|
||||
return nil, errNoSnapAvailable
|
||||
}
|
||||
|
||||
return snap, nil
|
||||
@@ -2733,9 +2784,9 @@ func decodeAppendEntry(msg []byte, sub *subscription, reply string) (*appendEntr
|
||||
ae.reply, ae.sub = reply, sub
|
||||
|
||||
// Decode Entries.
|
||||
ne, ri := int(le.Uint16(msg[40:])), uint64(42)
|
||||
ne, ri := int(le.Uint16(msg[40:])), uint64(appendEntryBaseLen)
|
||||
for i, max := 0, uint64(len(msg)); i < ne; i++ {
|
||||
if ri >= max-1 {
|
||||
if max-ri < 4 {
|
||||
return nil, errBadAppendEntry
|
||||
}
|
||||
ml := uint64(le.Uint32(msg[ri:]))
|
||||
@@ -2867,20 +2918,44 @@ func (n *raft) handleForwardedProposal(sub *subscription, c *client, _ *Account,
|
||||
msg = copyBytes(msg)
|
||||
|
||||
n.RLock()
|
||||
prop := n.prop
|
||||
// Check state under lock, we might not be leader anymore.
|
||||
if n.State() != Leader || !n.leaderState.Load() {
|
||||
n.debug("Ignoring forwarded proposal, not leader")
|
||||
n.RUnlock()
|
||||
return
|
||||
}
|
||||
prop, werr := n.prop, n.werr
|
||||
n.RUnlock()
|
||||
|
||||
// Ignore if we have had a write error previous.
|
||||
if werr != nil {
|
||||
if n.werr != nil {
|
||||
n.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
if n.isLeaderOverrun() {
|
||||
n.RUnlock()
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
// Now that we've reacquired as write lock, we need to make sure that everything we
|
||||
// believed before is still true. Otherwise we've either stepped down already from
|
||||
// another goroutine or we've stopped being overrun and shouldn't drop the entry.
|
||||
if n.State() != Leader || !n.leaderState.Load() {
|
||||
return
|
||||
} else if !n.isLeaderOverrun() {
|
||||
prop.push(newProposedEntry(newEntry(EntryNormal, msg), reply))
|
||||
return
|
||||
}
|
||||
var state StreamState
|
||||
n.wal.FastState(&state)
|
||||
n.warn("Leader falling behind, stepping down: pindex %d, commit %d, applied %d, WAL size %s", n.pindex, n.commit, n.applied, friendlyBytes(state.Bytes))
|
||||
// Stepdown without leader transfer, likely all replicas will be overrun, and we need time to recover.
|
||||
n.stepdownLocked(noLeader)
|
||||
n.overrunCount++
|
||||
return
|
||||
}
|
||||
// Possible that we could fall through to here from multiple connections but if
|
||||
// one does end up stepping down then the proposal queue gets drained anyway.
|
||||
n.RUnlock()
|
||||
prop.push(newProposedEntry(newEntry(EntryNormal, msg), reply))
|
||||
}
|
||||
|
||||
@@ -3252,8 +3327,6 @@ func (n *raft) sendSnapshotToFollower(subject string) (uint64, error) {
|
||||
if err != nil {
|
||||
// We need to stepdown here when this happens.
|
||||
n.stepdownLocked(noLeader)
|
||||
// We need to reset our state here as well.
|
||||
n.resetWAL()
|
||||
return 0, err
|
||||
}
|
||||
// Go ahead and send the snapshot and peerstate here as first append entry to the catchup follower.
|
||||
@@ -4041,6 +4114,42 @@ func (n *raft) processAppendEntry(ae *appendEntry, sub *subscription) {
|
||||
}
|
||||
}
|
||||
|
||||
// If commits are outpacing our applies, temporarily stop accepting new entries to avoid falling further behind.
|
||||
// This encourages the leader to sync us via a snapshot instead. We use max(applied, papplied) to avoid
|
||||
// incorrectly triggering this pause immediately after receiving a snapshot.
|
||||
applied := max(n.applied, n.papplied)
|
||||
commit := max(n.commit, n.papplied)
|
||||
if sub != nil && (commit > applied || n.quorumPaused) {
|
||||
diff := commit - applied
|
||||
if n.quorumPaused {
|
||||
if diff > paeWarnThreshold {
|
||||
if catchingUp {
|
||||
n.cancelCatchup()
|
||||
}
|
||||
n.Unlock()
|
||||
return
|
||||
}
|
||||
// Once we're sufficiently below the threshold, we continue again. We'll likely receive a snapshot
|
||||
// from the leader.
|
||||
n.quorumPaused = false
|
||||
var state StreamState
|
||||
n.wal.FastState(&state)
|
||||
n.warn("Quorum resumed: commit %d, applied %d, WAL size %s", commit, applied, friendlyBytes(state.Bytes))
|
||||
} else if diff > pauseQuorumThreshold {
|
||||
// It takes a while until we reach the pause threshold, but once we do we enter a "cooldown period".
|
||||
n.quorumPaused = true
|
||||
n.overrunCount++
|
||||
var state StreamState
|
||||
n.wal.FastState(&state)
|
||||
n.warn("Quorum paused, falling behind: commit %d != applied %d, WAL size %s", commit, applied, friendlyBytes(state.Bytes))
|
||||
if catchingUp {
|
||||
n.cancelCatchup()
|
||||
}
|
||||
n.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if ae.pterm != n.pterm || ae.pindex != n.pindex {
|
||||
// Check if this is a lower or equal index than what we were expecting.
|
||||
if ae.pindex <= n.pindex {
|
||||
@@ -4086,9 +4195,26 @@ func (n *raft) processAppendEntry(ae *appendEntry, sub *subscription) {
|
||||
}
|
||||
} else {
|
||||
// If terms mismatched, delete that entry and all others past it.
|
||||
// Make sure to cancel any catchups in progress.
|
||||
// Truncate will reset our pterm and pindex. Only do so if we have an entry.
|
||||
n.truncateWAL(eae.pterm, eae.pindex)
|
||||
// But only if we haven't already committed past this point.
|
||||
if eae.pindex < n.commit {
|
||||
success = true
|
||||
assert.Unreachable("Truncate to earlier entry would lose commits", map[string]any{
|
||||
"n.accName": n.accName,
|
||||
"n.group": n.group,
|
||||
"n.id": n.id,
|
||||
"n.term": n.term,
|
||||
"n.pindex": n.pindex,
|
||||
"n.commit": n.commit,
|
||||
"n.applied": n.applied,
|
||||
"ae.pindex": ae.pindex,
|
||||
"ae.pterm": ae.pterm,
|
||||
"ae.commit": ae.commit,
|
||||
"eae.pterm": eae.pterm,
|
||||
"eae.pindex": eae.pindex,
|
||||
})
|
||||
} else {
|
||||
n.truncateWAL(eae.pterm, eae.pindex)
|
||||
}
|
||||
}
|
||||
// Cancel regardless if unsuccessful.
|
||||
if !success {
|
||||
@@ -4420,9 +4546,10 @@ func (n *raft) storeToWAL(ae *appendEntry) error {
|
||||
}
|
||||
|
||||
const (
|
||||
paeDropThreshold = 20_000
|
||||
paeWarnThreshold = 10_000
|
||||
paeWarnModulo = 5_000
|
||||
pauseQuorumThreshold = 100_000
|
||||
paeDropThreshold = 20_000
|
||||
paeWarnThreshold = 10_000
|
||||
paeWarnModulo = 5_000
|
||||
)
|
||||
|
||||
func (n *raft) sendAppendEntry(entries []*Entry) {
|
||||
@@ -4699,11 +4826,18 @@ func (n *raft) setWriteErrLocked(err error) {
|
||||
}
|
||||
// If this is a not found report but do not disable.
|
||||
if os.IsNotExist(err) {
|
||||
n.error("Resource not found: %v", err)
|
||||
n.warn("Resource not found: %v", err)
|
||||
return
|
||||
}
|
||||
n.error("Critical write error: %v", err)
|
||||
n.werr = err
|
||||
n.shutdown()
|
||||
assert.Unreachable("Raft encountered write error", map[string]any{
|
||||
"n.accName": n.accName,
|
||||
"n.group": n.group,
|
||||
"n.id": n.id,
|
||||
"err": err,
|
||||
})
|
||||
|
||||
if isPermissionError(err) {
|
||||
go n.s.handleWritePermissionError()
|
||||
@@ -4720,6 +4854,13 @@ func (n *raft) isClosed() bool {
|
||||
return n.State() == Closed
|
||||
}
|
||||
|
||||
// GetWriteErr returns the write error (if any).
|
||||
func (n *raft) GetWriteErr() error {
|
||||
n.RLock()
|
||||
defer n.RUnlock()
|
||||
return n.werr
|
||||
}
|
||||
|
||||
// Capture our write error if any and hold.
|
||||
func (n *raft) setWriteErr(err error) {
|
||||
n.Lock()
|
||||
@@ -5033,6 +5174,8 @@ func (n *raft) switchToCandidate() {
|
||||
// Increment the term.
|
||||
n.term++
|
||||
n.vote = noVote
|
||||
// Reset quorum paused. If it was previously set, we checked above that we've applied all committed entries.
|
||||
n.quorumPaused = false
|
||||
// Clear current Leader.
|
||||
n.updateLeader(noLeader)
|
||||
n.switchState(Candidate)
|
||||
|
||||
+480
-272
@@ -1,4 +1,4 @@
|
||||
// Copyright 2017-2025 The NATS Authors
|
||||
// Copyright 2017-2026 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
@@ -740,6 +740,35 @@ func (jso jetStreamOption) IsStatszChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type jetStreamLimitsOption struct {
|
||||
noopOption
|
||||
newMaxMemory int64
|
||||
newMaxStore int64
|
||||
}
|
||||
|
||||
func (jso *jetStreamLimitsOption) Apply(s *Server) {
|
||||
js := s.getJetStream()
|
||||
if js == nil {
|
||||
return
|
||||
}
|
||||
js.mu.Lock()
|
||||
if jso.newMaxMemory > 0 {
|
||||
js.config.MaxMemory = jso.newMaxMemory
|
||||
atomic.StoreInt64(&js.memMax, js.config.MaxMemory)
|
||||
s.Noticef("Reloaded: JetStream max_mem_store = %s", friendlyBytes(jso.newMaxMemory))
|
||||
}
|
||||
if jso.newMaxStore > 0 {
|
||||
js.config.MaxStore = jso.newMaxStore
|
||||
atomic.StoreInt64(&js.storeMax, js.config.MaxStore)
|
||||
s.Noticef("Reloaded: JetStream max_file_store = %s", friendlyBytes(jso.newMaxStore))
|
||||
}
|
||||
js.mu.Unlock()
|
||||
}
|
||||
|
||||
func (jso *jetStreamLimitsOption) IsStatszChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type defaultSentinelOption struct {
|
||||
noopOption
|
||||
newValue string
|
||||
@@ -872,9 +901,207 @@ func (o *profBlockRateReload) Apply(s *Server) {
|
||||
|
||||
type leafNodeOption struct {
|
||||
noopOption
|
||||
tlsFirstChanged bool
|
||||
compressionChanged bool
|
||||
// These are for the remotes
|
||||
added []*RemoteLeafOpts
|
||||
changed map[*leafNodeCfg]*remoteLeafOption
|
||||
}
|
||||
|
||||
type remoteLeafOption struct {
|
||||
tlsFirstChanged bool
|
||||
compressionChanged bool
|
||||
disabledChanged bool
|
||||
opts *RemoteLeafOpts
|
||||
}
|
||||
|
||||
// Given `old` and `new` Leafnode options, this function will return the structure
|
||||
// used for applying the configuration, or an error is there are changes that
|
||||
// are not supported.
|
||||
func getLeafNodeOptionsChanges(s *Server, old, new *LeafNodeOpts) (*leafNodeOption, error) {
|
||||
|
||||
// We can't use DeepEqual for `Users` field, so do custom check.
|
||||
if usersHaveChanged(old.Users, new.Users) {
|
||||
return nil, fmt.Errorf("field \"Users\": old=%v, new=%v", old.Users, new.Users)
|
||||
}
|
||||
|
||||
// Check the main leafnodes{} block to see if there are any changes that are
|
||||
// not supported. We provide a list of fields to ignore (we already checked,
|
||||
// allow them to be modified or will check later).
|
||||
if err := checkConfigsEqual(old, new, []string{
|
||||
"Compression",
|
||||
"Remotes",
|
||||
"TLSHandshakeFirst",
|
||||
"TLSHandshakeFirstFallback",
|
||||
"TLSConfig",
|
||||
"Users",
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
const (
|
||||
remoteErrFormat = "remote %s: %s"
|
||||
maxAttempts = 20
|
||||
)
|
||||
var (
|
||||
nlo *leafNodeOption
|
||||
// Track whether any existing remote was not found (i.e. removed).
|
||||
removed bool
|
||||
)
|
||||
|
||||
forLoop:
|
||||
for failed := range maxAttempts {
|
||||
removed = false
|
||||
if failed > 0 {
|
||||
// If we failed once, we will wait a bit before trying again the remotes.
|
||||
// This could give enough time for connections that were in progress to complete.
|
||||
select {
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
case <-s.quitCh:
|
||||
return nil, ErrServerNotRunning
|
||||
}
|
||||
}
|
||||
nlo = &leafNodeOption{
|
||||
tlsFirstChanged: (old.TLSHandshakeFirst != new.TLSHandshakeFirst || old.TLSHandshakeFirstFallback != new.TLSHandshakeFirstFallback),
|
||||
compressionChanged: !old.Compression.equals(&new.Compression),
|
||||
// Start with all, will update when processing existing ones.
|
||||
// Since the list will be modified, we need to clone it.
|
||||
added: slices.Clone(new.Remotes),
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
// Go through the list of existing remote configurations.
|
||||
for lrc := range s.leafRemoteCfgs {
|
||||
var rlo *RemoteLeafOpts
|
||||
// Look for the corresponding `*RemoteLeafOpts` in the `nlo.added`
|
||||
// list. If it is found, that function returns an updated list
|
||||
// with the element removed from it.
|
||||
lrc.RLock()
|
||||
rlo, nlo.added = getRemoteLeafOpts(lrc.name(), nlo.added)
|
||||
if rlo == nil {
|
||||
// Not found, will be removed in leafNodeOption.Apply().
|
||||
removed = true
|
||||
lrc.RUnlock()
|
||||
continue
|
||||
}
|
||||
// Now we need to make sure that there are no changes that we don't
|
||||
// support for a RemoteLeafOpts.
|
||||
err := checkConfigsEqual(lrc.RemoteLeafOpts, rlo, []string{
|
||||
"Compression",
|
||||
"Disabled",
|
||||
"TLS",
|
||||
"TLSHandshakeFirst",
|
||||
"TLSConfig",
|
||||
})
|
||||
if err != nil {
|
||||
lrc.RUnlock()
|
||||
s.mu.RUnlock()
|
||||
return nil, fmt.Errorf(remoteErrFormat, rlo.safeName(), err)
|
||||
}
|
||||
disabledChanged := lrc.Disabled != rlo.Disabled
|
||||
// If this remote was disabled and is now enabled, we need to make sure
|
||||
// that there is no connect in progress. If that is the case, either
|
||||
// try again (if it is the first failure) or return an error.
|
||||
if disabledChanged && lrc.Disabled && lrc.connInProgress {
|
||||
lrc.RUnlock()
|
||||
s.mu.RUnlock()
|
||||
if failed < maxAttempts-1 {
|
||||
continue forLoop
|
||||
}
|
||||
return nil, fmt.Errorf(remoteErrFormat, rlo.safeName(),
|
||||
"cannot be enabled at the moment, try again")
|
||||
}
|
||||
// Since we will use the new `rlo.TLSConfig` later on, consider all
|
||||
// existing remote configs as "changed" and store them in the
|
||||
// `nlo.changed` map.
|
||||
if nlo.changed == nil {
|
||||
nlo.changed = make(map[*leafNodeCfg]*remoteLeafOption)
|
||||
}
|
||||
lnro := &remoteLeafOption{
|
||||
tlsFirstChanged: lrc.TLSHandshakeFirst != rlo.TLSHandshakeFirst,
|
||||
compressionChanged: !lrc.Compression.equals(&rlo.Compression),
|
||||
disabledChanged: disabledChanged,
|
||||
opts: rlo,
|
||||
}
|
||||
lrc.RUnlock()
|
||||
nlo.changed[lrc] = lnro
|
||||
}
|
||||
if len(nlo.added) > 0 {
|
||||
// Go through the added list and check if an added was recently removed and,
|
||||
// if that is the case, is it still in the `s.rmLeafRemoteCfgs` map, which
|
||||
// may mean that there was a connect-in-progress that did not complete yet.
|
||||
// Either try again (if it is the first failure) or return an error.
|
||||
for _, rlo := range nlo.added {
|
||||
if _, cip := s.rmLeafRemoteCfgs[rlo.name()]; cip {
|
||||
s.mu.RUnlock()
|
||||
if failed < maxAttempts-1 {
|
||||
continue forLoop
|
||||
}
|
||||
return nil, fmt.Errorf(remoteErrFormat, rlo.safeName(),
|
||||
"cannot be added at the moment, try again")
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
break
|
||||
}
|
||||
|
||||
// Now we want to make sure that there were actual changes, so that we don't
|
||||
// cause a reload of leafnodes for nothing. However, if one has (or all have)
|
||||
// been removed we still need to invoke leafNodeOption.Apply().
|
||||
if !nlo.tlsFirstChanged && !nlo.compressionChanged && !removed && len(nlo.added) == 0 && len(nlo.changed) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nlo, nil
|
||||
}
|
||||
|
||||
func usersHaveChanged(ousers, nusers []*User) bool {
|
||||
if len(ousers) != len(nusers) {
|
||||
return true
|
||||
}
|
||||
// We did not do a strict list order check in the past, so maintain this to
|
||||
// avoid possible breaking changes.
|
||||
oua := make(map[string]*User, len(ousers))
|
||||
nua := make(map[string]*User, len(nusers))
|
||||
for _, u := range ousers {
|
||||
oua[u.Username] = u
|
||||
}
|
||||
for _, u := range nusers {
|
||||
nua[u.Username] = u
|
||||
}
|
||||
for uname, u := range oua {
|
||||
// If we can not find new one with same name, consider that they have changed.
|
||||
nu, ok := nua[uname]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
// Same if password or account has changed.
|
||||
if u.Password != nu.Password || u.Account.GetName() != nu.Account.GetName() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Given the `search` remote leafnode options name, searches for a match in the `list`.
|
||||
// If found, returns the `*RemoteLeafOpts` from the list, and the updated list
|
||||
// without the element in it. If not found, returns `nil` and the unmodified list.
|
||||
func getRemoteLeafOpts(search string, list []*RemoteLeafOpts) (*RemoteLeafOpts, []*RemoteLeafOpts) {
|
||||
for i, rlo := range list {
|
||||
if search == rlo.name() {
|
||||
lastIdx := len(list) - 1
|
||||
if lastIdx == 0 {
|
||||
return rlo, nil
|
||||
}
|
||||
if i < lastIdx {
|
||||
list[i] = list[lastIdx]
|
||||
}
|
||||
list = list[:lastIdx]
|
||||
return rlo, list
|
||||
}
|
||||
}
|
||||
return nil, list
|
||||
}
|
||||
|
||||
func (l *leafNodeOption) Apply(s *Server) {
|
||||
@@ -882,101 +1109,181 @@ func (l *leafNodeOption) Apply(s *Server) {
|
||||
if l.tlsFirstChanged {
|
||||
s.Noticef("Reloaded: LeafNode TLS HandshakeFirst value is: %v", opts.LeafNode.TLSHandshakeFirst)
|
||||
s.Noticef("Reloaded: LeafNode TLS HandshakeFirstFallback value is: %v", opts.LeafNode.TLSHandshakeFirstFallback)
|
||||
for _, r := range opts.LeafNode.Remotes {
|
||||
s.Noticef("Reloaded: LeafNode Remote to %v TLS HandshakeFirst value is: %v", r.URLs, r.TLSHandshakeFirst)
|
||||
}
|
||||
}
|
||||
if l.compressionChanged || l.disabledChanged {
|
||||
var leafs []*client
|
||||
var solicit []*leafNodeCfg
|
||||
acceptSideCompOpts := &opts.LeafNode.Compression
|
||||
if l.compressionChanged {
|
||||
s.Noticef("Reloaded: LeafNode Compression value is: %v", opts.LeafNode.Compression)
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
// First, update our internal leaf remote configurations with the new
|
||||
// compress options.
|
||||
// Since changing the remotes (as in adding/removing) is currently not
|
||||
// supported, we know that we should have the same number in Options
|
||||
// than in leafRemoteCfgs, but to be sure, use the max size.
|
||||
max := len(opts.LeafNode.Remotes)
|
||||
if l := len(s.leafRemoteCfgs); l < max {
|
||||
max = l
|
||||
}
|
||||
for i := range max {
|
||||
lr := s.leafRemoteCfgs[i]
|
||||
or := opts.LeafNode.Remotes[i]
|
||||
lr.Lock()
|
||||
lr.Compression = or.Compression
|
||||
if lr.Disabled && !or.Disabled {
|
||||
solicit = append(solicit, lr)
|
||||
var close []*client
|
||||
var enable []*leafNodeCfg
|
||||
var removed bool
|
||||
|
||||
s.mu.Lock()
|
||||
acceptSideCompOpts := &opts.LeafNode.Compression
|
||||
// First go over existing leafnode remote configurations and
|
||||
// either remove if no longer present, or update the config.
|
||||
for lrc := range s.leafRemoteCfgs {
|
||||
rlo := l.changed[lrc]
|
||||
if rlo == nil {
|
||||
delete(s.leafRemoteCfgs, lrc)
|
||||
removed = true
|
||||
if s.rmLeafRemoteCfgs == nil {
|
||||
s.rmLeafRemoteCfgs = make(map[string]*leafNodeCfg)
|
||||
}
|
||||
lr.Disabled = or.Disabled
|
||||
lr.Unlock()
|
||||
s.rmLeafRemoteCfgs[lrc.name()] = lrc
|
||||
lrc.markAsRemoved()
|
||||
s.Noticef("Reloaded: LeafNode Remote %s removed", lrc.RemoteLeafOpts.safeName())
|
||||
// We will close the existing connection in the next for-loop.
|
||||
continue
|
||||
}
|
||||
|
||||
for _, l := range s.leafs {
|
||||
var co *CompressionOpts
|
||||
|
||||
l.mu.Lock()
|
||||
if r := l.leaf.remote; r != nil {
|
||||
// If newly marked as disabled, collect and ignore the rest.
|
||||
if r.Disabled {
|
||||
l.flags.set(noReconnect)
|
||||
leafs = append(leafs, l)
|
||||
l.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
co = &r.Compression
|
||||
lrc.Lock()
|
||||
// TLSConfig is always applied.
|
||||
lrc.TLSConfig = rlo.opts.TLSConfig.Clone()
|
||||
// Now update what has been detected has changed.
|
||||
if rlo.tlsFirstChanged {
|
||||
lrc.TLSHandshakeFirst = rlo.opts.TLSHandshakeFirst
|
||||
s.Noticef("Reloaded: LeafNode Remote %s TLS HandshakeFirst value is: %v",
|
||||
lrc.RemoteLeafOpts.safeName(), rlo.opts.TLSHandshakeFirst)
|
||||
}
|
||||
if rlo.compressionChanged {
|
||||
lrc.Compression = rlo.opts.Compression
|
||||
s.Noticef("Reloaded: LeafNode Remote %s Compression value is: %v",
|
||||
lrc.RemoteLeafOpts.safeName(), rlo.opts.Compression)
|
||||
}
|
||||
if rlo.disabledChanged {
|
||||
// Change to new value.
|
||||
lrc.Disabled = rlo.opts.Disabled
|
||||
if lrc.Disabled {
|
||||
lrc.notifyQuitChannel()
|
||||
} else {
|
||||
co = acceptSideCompOpts
|
||||
enable = append(enable, lrc)
|
||||
}
|
||||
newMode := co.Mode
|
||||
// Skip leaf connections that are "not supported" (because they
|
||||
// will never do compression) or the ones that have already the
|
||||
// new compression mode.
|
||||
if l.leaf.compression == CompressionNotSupported || l.leaf.compression == newMode {
|
||||
l.mu.Unlock()
|
||||
s.Noticef("Reloaded: LeafNode Remote %s Disabled value is: %v",
|
||||
lrc.RemoteLeafOpts.safeName(), rlo.opts.Disabled)
|
||||
}
|
||||
lrc.Unlock()
|
||||
}
|
||||
// Second, go over existing leaf connections and apply compression
|
||||
// changes (if applicable) and collect connections that need to be
|
||||
// closed and/or disabled.
|
||||
for _, c := range s.leafs {
|
||||
var co *CompressionOpts
|
||||
|
||||
c.mu.Lock()
|
||||
if r := c.leaf.remote; r != nil {
|
||||
rlo := l.changed[r]
|
||||
// If the config is not in the `changed` map, or the new config says that
|
||||
// the connection is disabled, collect so we can close it after the server
|
||||
// lock is released.
|
||||
if rlo == nil || (rlo.disabledChanged && rlo.opts.Disabled) {
|
||||
c.flags.set(noReconnect)
|
||||
close = append(close, c)
|
||||
c.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
// We need to close the connections if it had compression "off" or the new
|
||||
// mode is compression "off", or if the new mode is "accept", because
|
||||
// these require negotiation.
|
||||
if l.leaf.compression == CompressionOff || newMode == CompressionOff || newMode == CompressionAccept {
|
||||
leafs = append(leafs, l)
|
||||
} else if newMode == CompressionS2Auto {
|
||||
// If the mode is "s2_auto", we need to check if there is really
|
||||
// need to change, and at any rate, we want to save the actual
|
||||
// compression level here, not s2_auto.
|
||||
l.updateS2AutoCompressionLevel(co, &l.leaf.compression)
|
||||
} else {
|
||||
// Simply change the compression writer
|
||||
l.out.cw = s2.NewWriter(nil, s2WriterOptions(newMode)...)
|
||||
l.leaf.compression = newMode
|
||||
if rlo.compressionChanged {
|
||||
co = &r.Compression
|
||||
}
|
||||
l.mu.Unlock()
|
||||
} else if l.compressionChanged {
|
||||
co = acceptSideCompOpts
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
// Close the connections for which negotiation is required, or that
|
||||
// have been disabled.
|
||||
for _, l := range leafs {
|
||||
l.closeConnection(ClientClosed)
|
||||
if co != nil && applyCompressionChanges(c, co) {
|
||||
close = append(close, c)
|
||||
}
|
||||
if l.compressionChanged {
|
||||
s.Noticef("Reloaded: LeafNode compression settings")
|
||||
}
|
||||
if l.disabledChanged {
|
||||
if len(leafs) > 0 {
|
||||
s.Noticef("Reloaded: LeafNode(s) disabled")
|
||||
}
|
||||
if len(solicit) > 0 {
|
||||
for _, remote := range solicit {
|
||||
s.startGoRoutine(func() { s.connectToRemoteLeafNode(remote, true) })
|
||||
}
|
||||
s.Noticef("Reloaded: LeafNode(s) enabled")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Close the connections for which negotiation is required, have been disabled
|
||||
// or simply removed.
|
||||
for _, c := range close {
|
||||
c.closeConnection(ClientClosed)
|
||||
}
|
||||
// Start the ones that have been enabled.
|
||||
for _, r := range enable {
|
||||
s.connectToRemoteLeafNodeAsynchronously(r, true)
|
||||
}
|
||||
// Finally, deal with the ones that have been added.
|
||||
if len(l.added) > 0 {
|
||||
s.solicitLeafNodeRemotes(l.added)
|
||||
}
|
||||
// Deal with removed configs. Make sure there are no connect-in-progress.
|
||||
// If there are still some, have a go routine to check in the background.
|
||||
if removed {
|
||||
if checkAgain := checkRemovedLeafNodeCfgs(s); checkAgain {
|
||||
checkRemovedLeafNodeCfgsAsync(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Go through the removed remote leafnode configs map to check if the
|
||||
// connect-in-progress flag is set. If not, remove from the map.
|
||||
// Returns `true` if there are still some that are in progress.
|
||||
func checkRemovedLeafNodeCfgs(s *Server) bool {
|
||||
var inProgress int
|
||||
s.mu.Lock()
|
||||
for rn, r := range s.rmLeafRemoteCfgs {
|
||||
if r.isConnectInProgress() {
|
||||
inProgress++
|
||||
} else {
|
||||
delete(s.rmLeafRemoteCfgs, rn)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
// Needs to be called again if inProgress > 0
|
||||
return inProgress > 0
|
||||
}
|
||||
|
||||
// Will start a go routine that will periodically call `checkRemovedLeafNodeCfgs`.
|
||||
// When the removed map has been emptied, the go routine will end. It is ok for
|
||||
// this function to be invoked multiple times and have multiple instances running
|
||||
// concurrently.
|
||||
func checkRemovedLeafNodeCfgsAsync(s *Server) {
|
||||
s.startGoRoutine(func() {
|
||||
defer s.grWG.Done()
|
||||
tick := time.NewTicker(50 * time.Millisecond)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
if checkAgain := checkRemovedLeafNodeCfgs(s); !checkAgain {
|
||||
return
|
||||
}
|
||||
case <-s.quitCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// The `co` compression options are applied to the given leaf connection `c`.
|
||||
// If a "restart" of the connection is needed, will return true, false otherwise.
|
||||
func applyCompressionChanges(c *client, co *CompressionOpts) bool {
|
||||
newMode := co.Mode
|
||||
// Skip leaf connections that are "not supported" (because they
|
||||
// will never do compression) or the ones that have already the
|
||||
// new compression mode.
|
||||
if c.leaf.compression == CompressionNotSupported || c.leaf.compression == newMode {
|
||||
return false
|
||||
}
|
||||
// We need to close the connections if it had compression "off" or the new
|
||||
// mode is compression "off", or if the new mode is "accept", because
|
||||
// these require negotiation.
|
||||
if c.leaf.compression == CompressionOff || newMode == CompressionOff || newMode == CompressionAccept {
|
||||
return true
|
||||
} else if newMode == CompressionS2Auto {
|
||||
// If the mode is "s2_auto", we need to check if there is really
|
||||
// need to change, and at any rate, we want to save the actual
|
||||
// compression level here, not s2_auto.
|
||||
c.updateS2AutoCompressionLevel(co, &c.leaf.compression)
|
||||
} else {
|
||||
// Simply change the compression writer
|
||||
c.out.cw = s2.NewWriter(nil, s2WriterOptions(newMode)...)
|
||||
c.leaf.compression = newMode
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type noFastProdStallReload struct {
|
||||
noopOption
|
||||
noStall bool
|
||||
@@ -1031,7 +1338,7 @@ func (s *Server) recheckPinnedCerts(curOpts *Options, newOpts *Options) {
|
||||
}
|
||||
})
|
||||
}
|
||||
if s.gateway.enabled && reflect.DeepEqual(newOpts.Gateway.TLSPinnedCerts, curOpts.Gateway.TLSPinnedCerts) {
|
||||
if s.gateway.enabled && !reflect.DeepEqual(newOpts.Gateway.TLSPinnedCerts, curOpts.Gateway.TLSPinnedCerts) {
|
||||
gw := s.gateway
|
||||
gw.RLock()
|
||||
for _, c := range gw.out {
|
||||
@@ -1115,11 +1422,6 @@ func (s *Server) ReloadOptions(newOpts *Options) error {
|
||||
|
||||
curOpts := s.getOpts()
|
||||
|
||||
// Wipe trusted keys if needed when we have an operator.
|
||||
if len(curOpts.TrustedOperators) > 0 && len(curOpts.TrustedKeys) > 0 {
|
||||
curOpts.TrustedKeys = nil
|
||||
}
|
||||
|
||||
clientOrgPort := curOpts.Port
|
||||
clusterOrgPort := curOpts.Cluster.Port
|
||||
gatewayOrgPort := curOpts.Gateway.Port
|
||||
@@ -1215,15 +1517,18 @@ func (s *Server) reloadOptions(curOpts, newOpts *Options) error {
|
||||
newOpts.CustomClientAuthentication = curOpts.CustomClientAuthentication
|
||||
newOpts.CustomRouterAuthentication = curOpts.CustomRouterAuthentication
|
||||
|
||||
changed, err := s.diffOptions(newOpts)
|
||||
if err != nil {
|
||||
// Do the validation before checking for differences. We need to ensure
|
||||
// that the new options are valid. Note that there are possible side
|
||||
// effects of calling validateOptions(), in that some default values
|
||||
// may be set, etc... but that should be ok since the current options
|
||||
// went through the same process on startup/previous reload.
|
||||
if err := validateOptions(newOpts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(changed) != 0 {
|
||||
if err := validateOptions(newOpts); err != nil {
|
||||
return err
|
||||
}
|
||||
changed, err := s.diffOptions(newOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a context that is used to pass special info that we may need
|
||||
@@ -1258,7 +1563,7 @@ func imposeOrder(value any) error {
|
||||
slices.Sort(value.AllowedOrigins)
|
||||
case string, bool, uint8, uint16, uint64, int, int32, int64, time.Duration, float64, nil, LeafNodeOpts, ClusterOpts, *tls.Config, PinnedCertSet,
|
||||
*URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication, MQTTOpts, jwt.TagList,
|
||||
*OCSPConfig, map[string]string, JSLimitOpts, StoreCipher, *OCSPResponseCacheConfig, *ProxiesConfig, WriteTimeoutPolicy:
|
||||
*OCSPConfig, map[string]string, map[string]bool, JSLimitOpts, StoreCipher, *OCSPResponseCacheConfig, *ProxiesConfig, WriteTimeoutPolicy:
|
||||
// explicitly skipped types
|
||||
case *AuthCallout:
|
||||
case JSTpmOpts:
|
||||
@@ -1275,9 +1580,11 @@ func imposeOrder(value any) error {
|
||||
// error.
|
||||
func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
var (
|
||||
oldConfig = reflect.ValueOf(s.getOpts()).Elem()
|
||||
oldOpts = s.getOpts()
|
||||
oldConfig = reflect.ValueOf(oldOpts).Elem()
|
||||
newConfig = reflect.ValueOf(newOpts).Elem()
|
||||
diffOpts = []option{}
|
||||
skipTKeys = len(oldOpts.TrustedOperators) > 0 && len(oldOpts.TrustedKeys) > 0
|
||||
|
||||
// Need to keep track of whether JS is being disabled
|
||||
// to prevent changing limits at runtime.
|
||||
@@ -1286,6 +1593,7 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
jsMemLimitsChanged bool
|
||||
jsFileLimitsChanged bool
|
||||
jsStoreDirChanged bool
|
||||
jsLimitsUpdate *jetStreamLimitsOption
|
||||
)
|
||||
for i := 0; i < oldConfig.NumField(); i++ {
|
||||
field := oldConfig.Type().Field(i)
|
||||
@@ -1294,6 +1602,17 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
if field.PkgPath != _EMPTY_ {
|
||||
continue
|
||||
}
|
||||
optName := strings.ToLower(field.Name)
|
||||
if skipTKeys && optName == "trustedkeys" {
|
||||
// TrustedOperators and TrustedKeys change is not supported. During options
|
||||
// validation, if they are both specified, a conflict error is returned.
|
||||
// If only TrustedOperators is specified, the TrustedKeys is filled with
|
||||
// the operators' signing keys. So here, if we detect that the current
|
||||
// options have operators, we don't do the trusted keys comparison, so
|
||||
// we can fail with the "not supported for TrustedOperators" config reload
|
||||
// error instead of TrustedKeys (that the user would not have set).
|
||||
continue
|
||||
}
|
||||
var (
|
||||
oldValue = oldConfig.Field(i).Interface()
|
||||
newValue = newConfig.Field(i).Interface()
|
||||
@@ -1305,7 +1624,6 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
optName := strings.ToLower(field.Name)
|
||||
// accounts and users (referencing accounts) will always differ as accounts
|
||||
// contain internal state, say locks etc..., so we don't bother here.
|
||||
// This also avoids races with atomic stats counters
|
||||
@@ -1374,7 +1692,7 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
co := &clusterOption{
|
||||
newValue: newClusterOpts,
|
||||
permsChanged: !reflect.DeepEqual(newClusterOpts.Permissions, oldClusterOpts.Permissions),
|
||||
compressChanged: !compressOptsEqual(&oldClusterOpts.Compression, &newClusterOpts.Compression),
|
||||
compressChanged: !oldClusterOpts.Compression.equals(&newClusterOpts.Compression),
|
||||
}
|
||||
co.diffPoolAndAccounts(&oldClusterOpts)
|
||||
// If there are added accounts, first make sure that we can look them up.
|
||||
@@ -1445,6 +1763,11 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
tmpOld.tlsConfigOpts = nil
|
||||
tmpNew.tlsConfigOpts = nil
|
||||
|
||||
// Allow TLSPinnedCerts through reload, existing connections
|
||||
// are checked in recheckPinnedCerts
|
||||
tmpOld.TLSPinnedCerts = nil
|
||||
tmpNew.TLSPinnedCerts = nil
|
||||
|
||||
// Need to do the same for remote gateways' TLS configs.
|
||||
// But we can't just set remotes' TLSConfig to nil otherwise this
|
||||
// would lose the real TLS configuration.
|
||||
@@ -1458,149 +1781,19 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
field.Name, oldValue, newValue)
|
||||
}
|
||||
case "leafnode":
|
||||
// Similar to gateways
|
||||
tmpOld := oldValue.(LeafNodeOpts)
|
||||
tmpNew := newValue.(LeafNodeOpts)
|
||||
tmpOld.TLSConfig = nil
|
||||
tmpNew.TLSConfig = nil
|
||||
tmpOld.tlsConfigOpts = nil
|
||||
tmpNew.tlsConfigOpts = nil
|
||||
// We will allow TLSHandshakeFirst to be config reloaded. First,
|
||||
// we just want to detect if there was a change in the leafnodes{}
|
||||
// block, and if not, we will check the remotes.
|
||||
handshakeFirstChanged := tmpOld.TLSHandshakeFirst != tmpNew.TLSHandshakeFirst ||
|
||||
tmpOld.TLSHandshakeFirstFallback != tmpNew.TLSHandshakeFirstFallback
|
||||
// If changed, set them (in the temporary variables) to false so that the
|
||||
// rest of the comparison does not fail.
|
||||
if handshakeFirstChanged {
|
||||
tmpOld.TLSHandshakeFirst, tmpNew.TLSHandshakeFirst = false, false
|
||||
tmpOld.TLSHandshakeFirstFallback, tmpNew.TLSHandshakeFirstFallback = 0, 0
|
||||
} else if len(tmpOld.Remotes) == len(tmpNew.Remotes) {
|
||||
// Since we don't support changes in the remotes, we will do a
|
||||
// simple pass to see if there was a change of this field.
|
||||
for i := 0; i < len(tmpOld.Remotes); i++ {
|
||||
if tmpOld.Remotes[i].TLSHandshakeFirst != tmpNew.Remotes[i].TLSHandshakeFirst {
|
||||
handshakeFirstChanged = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
lno, err := getLeafNodeOptionsChanges(s, &tmpOld, &tmpNew)
|
||||
// If there was an unsupported change, we will get an error with the name
|
||||
// of the (first) field and its old and new value.
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config reload not supported for %s: %v", field.Name, err)
|
||||
}
|
||||
// We also support config reload for compression. Check if it changed before
|
||||
// blanking them out for the deep-equal check at the end.
|
||||
compressionChanged := !compressOptsEqual(&tmpOld.Compression, &tmpNew.Compression)
|
||||
if compressionChanged {
|
||||
tmpOld.Compression, tmpNew.Compression = CompressionOpts{}, CompressionOpts{}
|
||||
} else if len(tmpOld.Remotes) == len(tmpNew.Remotes) {
|
||||
// Same that for tls first check, do the remotes now.
|
||||
for i := range len(tmpOld.Remotes) {
|
||||
if !compressOptsEqual(&tmpOld.Remotes[i].Compression, &tmpNew.Remotes[i].Compression) {
|
||||
compressionChanged = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// If there was an actual change...
|
||||
if lno != nil {
|
||||
diffOpts = append(diffOpts, lno)
|
||||
}
|
||||
// Check if the "disabled" option of each remote has changed.
|
||||
var disabledChanged bool
|
||||
for i := range len(tmpOld.Remotes) {
|
||||
if tmpOld.Remotes[i].Disabled != tmpNew.Remotes[i].Disabled {
|
||||
disabledChanged = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Need to do the same for remote leafnodes' TLS configs.
|
||||
// But we can't just set remotes' TLSConfig to nil otherwise this
|
||||
// would lose the real TLS configuration.
|
||||
tmpOld.Remotes = copyRemoteLNConfigForReloadCompare(tmpOld.Remotes)
|
||||
tmpNew.Remotes = copyRemoteLNConfigForReloadCompare(tmpNew.Remotes)
|
||||
|
||||
// Special check for leafnode remotes changes which are not supported right now.
|
||||
leafRemotesChanged := func(a, b LeafNodeOpts) bool {
|
||||
if len(a.Remotes) != len(b.Remotes) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check whether all remotes URLs are still the same.
|
||||
for _, oldRemote := range a.Remotes {
|
||||
var found bool
|
||||
|
||||
if oldRemote.LocalAccount == _EMPTY_ {
|
||||
oldRemote.LocalAccount = globalAccountName
|
||||
}
|
||||
|
||||
for _, newRemote := range b.Remotes {
|
||||
// Bind to global account in case not defined.
|
||||
if newRemote.LocalAccount == _EMPTY_ {
|
||||
newRemote.LocalAccount = globalAccountName
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(oldRemote, newRemote) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// First check whether remotes changed at all. If they did not,
|
||||
// skip them in the complete equal check.
|
||||
if !leafRemotesChanged(tmpOld, tmpNew) {
|
||||
tmpOld.Remotes = nil
|
||||
tmpNew.Remotes = nil
|
||||
}
|
||||
|
||||
// Special check for auth users to detect changes.
|
||||
// If anything is off will fall through and fail below.
|
||||
// If we detect they are semantically the same we nil them out
|
||||
// to pass the check below.
|
||||
if tmpOld.Users != nil || tmpNew.Users != nil {
|
||||
if len(tmpOld.Users) == len(tmpNew.Users) {
|
||||
oua := make(map[string]*User, len(tmpOld.Users))
|
||||
nua := make(map[string]*User, len(tmpOld.Users))
|
||||
for _, u := range tmpOld.Users {
|
||||
oua[u.Username] = u
|
||||
}
|
||||
for _, u := range tmpNew.Users {
|
||||
nua[u.Username] = u
|
||||
}
|
||||
same := true
|
||||
for uname, u := range oua {
|
||||
// If we can not find new one with same name, drop through to fail.
|
||||
nu, ok := nua[uname]
|
||||
if !ok {
|
||||
same = false
|
||||
break
|
||||
}
|
||||
// If username or password or account different break.
|
||||
if u.Username != nu.Username || u.Password != nu.Password || u.Account.GetName() != nu.Account.GetName() {
|
||||
same = false
|
||||
break
|
||||
}
|
||||
}
|
||||
// We can nil out here.
|
||||
if same {
|
||||
tmpOld.Users, tmpNew.Users = nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there is really a change prevents reload.
|
||||
if !reflect.DeepEqual(tmpOld, tmpNew) {
|
||||
// See TODO(ik) note below about printing old/new values.
|
||||
return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v",
|
||||
field.Name, oldValue, newValue)
|
||||
}
|
||||
|
||||
diffOpts = append(diffOpts, &leafNodeOption{
|
||||
tlsFirstChanged: handshakeFirstChanged,
|
||||
compressionChanged: compressionChanged,
|
||||
disabledChanged: disabledChanged,
|
||||
})
|
||||
case "jetstream":
|
||||
new := newValue.(bool)
|
||||
old := oldValue.(bool)
|
||||
@@ -1636,26 +1829,35 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
|
||||
fromSet = !fromUnset
|
||||
toUnset = new == -1
|
||||
toSet = !toUnset
|
||||
increased = fromSet && toSet && new > old
|
||||
)
|
||||
if jsEnabled && modified {
|
||||
// Cannot change limits from dynamic storage at runtime.
|
||||
switch {
|
||||
case increased:
|
||||
// Allowed to increase, but not decrease.
|
||||
if jsLimitsUpdate == nil {
|
||||
jsLimitsUpdate = &jetStreamLimitsOption{}
|
||||
diffOpts = append(diffOpts, jsLimitsUpdate)
|
||||
}
|
||||
if optName == "jetstreammaxmemory" {
|
||||
jsLimitsUpdate.newMaxMemory = new
|
||||
} else {
|
||||
jsLimitsUpdate.newMaxStore = new
|
||||
}
|
||||
case fromSet && toUnset:
|
||||
// Limits changed but it may mean that JS is being disabled,
|
||||
// keep track of the change and error in case it is not.
|
||||
switch optName {
|
||||
case "jetstreammaxmemory":
|
||||
if optName == "jetstreammaxmemory" {
|
||||
jsMemLimitsChanged = true
|
||||
case "jetstreammaxstore":
|
||||
} else {
|
||||
jsFileLimitsChanged = true
|
||||
default:
|
||||
return nil, fmt.Errorf("config reload not supported for jetstream max memory and store")
|
||||
}
|
||||
case fromUnset && toSet:
|
||||
// Prevent changing from dynamic max memory / file at runtime.
|
||||
return nil, fmt.Errorf("config reload not supported for jetstream dynamic max memory and store")
|
||||
default:
|
||||
return nil, fmt.Errorf("config reload not supported for jetstream max memory and store")
|
||||
return nil, fmt.Errorf("config reload not supported for decreasing jetstream max memory and store")
|
||||
}
|
||||
}
|
||||
case "jetstreammetacompact", "jetstreammetacompactsize", "jetstreammetacompactsync":
|
||||
@@ -1801,32 +2003,6 @@ func copyRemoteGWConfigsWithoutTLSConfig(current []*RemoteGatewayOpts) []*Remote
|
||||
return rgws
|
||||
}
|
||||
|
||||
func copyRemoteLNConfigForReloadCompare(current []*RemoteLeafOpts) []*RemoteLeafOpts {
|
||||
l := len(current)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
rlns := make([]*RemoteLeafOpts, 0, l)
|
||||
for _, rcfg := range current {
|
||||
cp := *rcfg
|
||||
cp.TLSConfig = nil
|
||||
cp.tlsConfigOpts = nil
|
||||
cp.TLSHandshakeFirst = false
|
||||
// This is set only when processing a CONNECT, so reset here so that we
|
||||
// don't fail the DeepEqual comparison.
|
||||
cp.TLS = false
|
||||
// For now, remove DenyImports/Exports since those get modified at runtime
|
||||
// to add JS APIs.
|
||||
cp.DenyImports, cp.DenyExports = nil, nil
|
||||
// Remove compression mode
|
||||
cp.Compression = CompressionOpts{}
|
||||
// Reset disabled status
|
||||
cp.Disabled = false
|
||||
rlns = append(rlns, &cp)
|
||||
}
|
||||
return rlns
|
||||
}
|
||||
|
||||
func (s *Server) applyOptions(ctx *reloadContext, opts []option) {
|
||||
var (
|
||||
reloadLogging = false
|
||||
@@ -1896,15 +2072,12 @@ func (s *Server) applyOptions(ctx *reloadContext, opts []option) {
|
||||
s.sendStatszUpdate()
|
||||
}
|
||||
|
||||
// For remote gateways and leafnodes, make sure that their TLS configuration
|
||||
// For remote gateways, make sure that their TLS configuration
|
||||
// is updated (since the config is "captured" early and changes would otherwise
|
||||
// not be visible).
|
||||
if s.gateway.enabled {
|
||||
s.gateway.updateRemotesTLSConfig(newOpts)
|
||||
}
|
||||
if len(newOpts.LeafNode.Remotes) > 0 {
|
||||
s.updateRemoteLeafNodesTLSConfig(newOpts)
|
||||
}
|
||||
|
||||
// Always restart OCSP monitoring on reload.
|
||||
if err := s.reloadOCSP(); err != nil {
|
||||
@@ -2650,3 +2823,38 @@ func diffProxiesTrustedKeys(old, new []*ProxyConfig) ([]string, []string) {
|
||||
}
|
||||
return add, del
|
||||
}
|
||||
|
||||
// This function calls `reflect.DeepEqual` on all public fields that are
|
||||
// not part of the `ignoreFields` list. If they are all equal, returns nil,
|
||||
// otherwise returns an error that will contain the name of the first field
|
||||
// that fails the comparison, along with its old and new values.
|
||||
func checkConfigsEqual(c1, c2 any, ignoreFields []string) error {
|
||||
oldConfig := reflect.ValueOf(c1).Elem()
|
||||
newConfig := reflect.ValueOf(c2).Elem()
|
||||
for i := 0; i < oldConfig.NumField(); i++ {
|
||||
field := oldConfig.Type().Field(i)
|
||||
// field.PkgPath is empty for exported fields, and is not for unexported ones.
|
||||
// We skip the unexported fields.
|
||||
if field.PkgPath != _EMPTY_ {
|
||||
continue
|
||||
}
|
||||
// If it is in the set of fields to ignore, move to the next.
|
||||
// We expect the number of ignore fields to be small.
|
||||
var ignored bool
|
||||
for _, f := range ignoreFields {
|
||||
if f == field.Name {
|
||||
ignored = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if ignored {
|
||||
continue
|
||||
}
|
||||
oldValue := oldConfig.Field(i).Interface()
|
||||
newValue := newConfig.Field(i).Interface()
|
||||
if !reflect.DeepEqual(oldValue, newValue) {
|
||||
return fmt.Errorf("field %q: old=%v, new=%v", field.Name, oldValue, newValue)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+126
-13
@@ -19,6 +19,7 @@ import (
|
||||
"io"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats-server/v2/server/thw"
|
||||
@@ -77,6 +78,18 @@ func (ms *MsgScheduling) init(seq uint64, subj string, ts int64) {
|
||||
delete(ms.inflight, subj)
|
||||
}
|
||||
|
||||
func (ms *MsgScheduling) update(subj string, ts int64) {
|
||||
if sched, ok := ms.schedules[subj]; ok {
|
||||
// Remove and add separately, it's for the same sequence, but if replicated
|
||||
// this server could not know the previous timestamp yet.
|
||||
ms.ttls.Remove(sched.seq, sched.ts)
|
||||
ms.ttls.Add(sched.seq, ts)
|
||||
sched.ts = ts
|
||||
delete(ms.inflight, subj)
|
||||
ms.resetTimer()
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *MsgScheduling) markInflight(subj string) {
|
||||
if _, ok := ms.schedules[subj]; ok {
|
||||
ms.inflight[subj] = struct{}{}
|
||||
@@ -90,8 +103,7 @@ func (ms *MsgScheduling) isInflight(subj string) bool {
|
||||
|
||||
func (ms *MsgScheduling) remove(seq uint64) {
|
||||
if subj, ok := ms.seqToSubj[seq]; ok {
|
||||
delete(ms.seqToSubj, seq)
|
||||
delete(ms.schedules, subj)
|
||||
ms.removeSubject(subj)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,6 +112,7 @@ func (ms *MsgScheduling) removeSubject(subj string) {
|
||||
ms.ttls.Remove(sched.seq, sched.ts)
|
||||
delete(ms.schedules, subj)
|
||||
delete(ms.seqToSubj, sched.seq)
|
||||
delete(ms.inflight, subj)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,11 +155,12 @@ func (ms *MsgScheduling) resetTimer() {
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *StoreMsg) *StoreMsg) []*inMsg {
|
||||
func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *StoreMsg) *StoreMsg, loadLast func(subj string, smv *StoreMsg) *StoreMsg) []*inMsg {
|
||||
var (
|
||||
smv StoreMsg
|
||||
sm *StoreMsg
|
||||
msgs []*inMsg
|
||||
smv StoreMsg
|
||||
srcSmv StoreMsg
|
||||
sm *StoreMsg
|
||||
msgs []*inMsg
|
||||
)
|
||||
ms.ttls.ExpireTasks(func(seq uint64, ts int64) bool {
|
||||
// Need to grab the message for the specified sequence, and check
|
||||
@@ -155,10 +169,26 @@ func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *Stor
|
||||
if sm != nil {
|
||||
// If already inflight, don't duplicate a scheduled message. The stream could
|
||||
// be replicated and the scheduled message could take some time to propagate.
|
||||
if ms.isInflight(sm.subj) {
|
||||
subj := sm.subj
|
||||
if ms.isInflight(subj) {
|
||||
return false
|
||||
}
|
||||
// Validate the contents are correct if not, we just remove it from THW.
|
||||
pattern := bytesToString(sliceHeader(JSSchedulePattern, sm.hdr))
|
||||
if pattern == _EMPTY_ {
|
||||
ms.remove(seq)
|
||||
return true
|
||||
}
|
||||
loc, apiErr := loadMessageScheduleLocation(sm.hdr)
|
||||
if apiErr != nil {
|
||||
ms.remove(seq)
|
||||
return true
|
||||
}
|
||||
next, repeat, ok := parseMsgSchedule(pattern, loc, ts)
|
||||
if !ok {
|
||||
ms.remove(seq)
|
||||
return true
|
||||
}
|
||||
ttl, ok := getMessageScheduleTTL(sm.hdr)
|
||||
if !ok {
|
||||
ms.remove(seq)
|
||||
@@ -169,27 +199,43 @@ func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *Stor
|
||||
ms.remove(seq)
|
||||
return true
|
||||
}
|
||||
rollup := getMessageScheduleRollup(sm.hdr)
|
||||
source := getMessageScheduleSource(sm.hdr)
|
||||
if source != _EMPTY_ {
|
||||
// Fall back to the scheduled message's own content if the source has no last message.
|
||||
if srcSm := loadLast(source, &srcSmv); srcSm != nil {
|
||||
sm = srcSm
|
||||
}
|
||||
}
|
||||
|
||||
// Copy, as this is retrieved directly from storage, and we'll need to keep hold of this for some time.
|
||||
// And in the case of headers, we'll copy all of them, but make changes.
|
||||
hdr, msg := copyBytes(sm.hdr), copyBytes(sm.msg)
|
||||
|
||||
// Strip headers specific to the schedule.
|
||||
hdr = removeHeaderIfPresent(hdr, JSSchedulePattern)
|
||||
hdr = removeHeaderIfPrefixPresent(hdr, "Nats-Schedule-")
|
||||
// Strip headers specific to message scheduling.
|
||||
// Covers Nats-Schedule, Nats-Schedule-*, and Nats-Scheduler.
|
||||
hdr = removeHeaderIfPrefixPresent(hdr, "Nats-Schedule")
|
||||
// Strip headers that could prevent persisting this scheduled message.
|
||||
hdr = removeHeaderIfPrefixPresent(hdr, "Nats-Expected-")
|
||||
hdr = removeHeaderIfPresent(hdr, JSMsgId)
|
||||
hdr = removeHeaderIfPresent(hdr, JSMessageTTL)
|
||||
hdr = removeHeaderIfPresent(hdr, JSMsgRollup)
|
||||
|
||||
// Add headers for the scheduled message.
|
||||
hdr = genHeader(hdr, JSScheduler, sm.subj)
|
||||
hdr = genHeader(hdr, JSScheduleNext, JSScheduleNextPurge) // Purge the schedule message itself.
|
||||
hdr = genHeader(hdr, JSScheduler, subj)
|
||||
if !repeat {
|
||||
hdr = genHeader(hdr, JSScheduleNext, JSScheduleNextPurge) // Purge the schedule message itself.
|
||||
} else {
|
||||
hdr = genHeader(hdr, JSScheduleNext, next.Format(time.RFC3339)) // Next time the schedule fires.
|
||||
}
|
||||
if ttl != _EMPTY_ {
|
||||
hdr = genHeader(hdr, JSMessageTTL, ttl)
|
||||
}
|
||||
if rollup != _EMPTY_ {
|
||||
hdr = genHeader(hdr, JSMsgRollup, rollup)
|
||||
}
|
||||
msgs = append(msgs, &inMsg{seq: seq, subj: target, hdr: hdr, msg: msg})
|
||||
ms.markInflight(sm.subj)
|
||||
ms.markInflight(subj)
|
||||
return false
|
||||
}
|
||||
ms.remove(seq)
|
||||
@@ -261,3 +307,70 @@ func (ms *MsgScheduling) decode(b []byte) (uint64, error) {
|
||||
}
|
||||
return stamp, nil
|
||||
}
|
||||
|
||||
// parseMsgSchedule parses a message schedule pattern and returns the time
|
||||
// to fire, whether it is a repeating schedule, and whether the pattern was valid.
|
||||
func parseMsgSchedule(pattern string, loc *time.Location, ts int64) (time.Time, bool, bool) {
|
||||
if pattern == _EMPTY_ {
|
||||
return time.Time{}, false, true
|
||||
}
|
||||
// Exact time.
|
||||
if strings.HasPrefix(pattern, "@at ") {
|
||||
// Time zone is not supported for @at.
|
||||
if loc != nil {
|
||||
return time.Time{}, false, false
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, pattern[4:])
|
||||
return t, false, err == nil
|
||||
}
|
||||
// Repeating on a simple interval.
|
||||
if strings.HasPrefix(pattern, "@every ") {
|
||||
// Time zone is not supported for @every.
|
||||
if loc != nil {
|
||||
return time.Time{}, false, false
|
||||
}
|
||||
dur, err := time.ParseDuration(pattern[7:])
|
||||
if err != nil {
|
||||
return time.Time{}, false, false
|
||||
}
|
||||
// Only allow intervals of at least a second.
|
||||
if dur.Seconds() < 1 {
|
||||
return time.Time{}, false, false
|
||||
}
|
||||
// If this schedule would trigger multiple times, for example after a restart, skip ahead and only fire once.
|
||||
next := time.Unix(0, ts).UTC().Round(time.Second).Add(dur)
|
||||
if now := time.Now().UTC(); next.Before(now) {
|
||||
next = now.Round(time.Second).Add(dur)
|
||||
}
|
||||
return next, true, true
|
||||
}
|
||||
|
||||
// Predefined schedules for cron.
|
||||
switch pattern {
|
||||
case "@yearly", "@annually":
|
||||
pattern = "0 0 0 1 1 *"
|
||||
case "@monthly":
|
||||
pattern = "0 0 0 1 * *"
|
||||
case "@weekly":
|
||||
pattern = "0 0 0 * * 0"
|
||||
case "@daily", "@midnight":
|
||||
pattern = "0 0 0 * * *"
|
||||
case "@hourly":
|
||||
pattern = "0 0 * * * *"
|
||||
}
|
||||
|
||||
// Parse the cron pattern.
|
||||
next, err := parseCron(pattern, loc, ts)
|
||||
if err != nil {
|
||||
return time.Time{}, false, false
|
||||
}
|
||||
// If this schedule would trigger multiple times, for example after a restart, skip ahead and only fire once.
|
||||
if now := time.Now().UTC(); next.Before(now) {
|
||||
ts = now.Round(time.Second).UnixNano()
|
||||
next, err = parseCron(pattern, loc, ts)
|
||||
if err != nil {
|
||||
return time.Time{}, false, false
|
||||
}
|
||||
}
|
||||
return next, true, true
|
||||
}
|
||||
|
||||
+7
-36
@@ -31,7 +31,6 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
@@ -234,7 +233,8 @@ type Server struct {
|
||||
resolver netResolver
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
leafRemoteCfgs []*leafNodeCfg
|
||||
leafRemoteCfgs map[*leafNodeCfg]struct{}
|
||||
rmLeafRemoteCfgs map[string]*leafNodeCfg
|
||||
leafRemoteAccounts sync.Map
|
||||
leafNodeEnabled bool
|
||||
leafDisableConnect bool // Used in test only
|
||||
@@ -367,7 +367,8 @@ type Server struct {
|
||||
syncOutSem chan struct{}
|
||||
|
||||
// Queue to process JS API requests that come from routes (or gateways)
|
||||
jsAPIRoutedReqs *ipQueue[*jsAPIRoutedReq]
|
||||
jsAPIRoutedReqs *ipQueue[*jsAPIRoutedReq]
|
||||
jsAPIRoutedInfoReqs *ipQueue[*jsAPIRoutedReq]
|
||||
|
||||
// Delayed API responses.
|
||||
delayedAPIResponses *ipQueue[*delayedAPIResponse]
|
||||
@@ -645,32 +646,6 @@ func selectS2AutoModeBasedOnRTT(rtt time.Duration, rttThresholds []time.Duration
|
||||
return CompressionS2Best
|
||||
}
|
||||
|
||||
func compressOptsEqual(c1, c2 *CompressionOpts) bool {
|
||||
if c1 == c2 {
|
||||
return true
|
||||
}
|
||||
if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) {
|
||||
return false
|
||||
}
|
||||
if c1.Mode != c2.Mode {
|
||||
return false
|
||||
}
|
||||
// For s2_auto, if one has an empty RTTThresholds, it is equivalent
|
||||
// to the defaultCompressionS2AutoRTTThresholds array, so compare with that.
|
||||
if c1.Mode == CompressionS2Auto {
|
||||
if len(c1.RTTThresholds) == 0 && !reflect.DeepEqual(c2.RTTThresholds, defaultCompressionS2AutoRTTThresholds) {
|
||||
return false
|
||||
}
|
||||
if len(c2.RTTThresholds) == 0 && !reflect.DeepEqual(c1.RTTThresholds, defaultCompressionS2AutoRTTThresholds) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(c1.RTTThresholds, c2.RTTThresholds) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Returns an array of s2 WriterOption based on the route compression mode.
|
||||
// So far we return a single option, but this way we can call s2.NewWriter()
|
||||
// with a nil []s2.WriterOption, but not with a nil s2.WriterOption, so
|
||||
@@ -1956,12 +1931,7 @@ func (s *Server) registerAccount(acc *Account) *Account {
|
||||
// Helper to set the sublist based on preferences.
|
||||
func (s *Server) setAccountSublist(acc *Account) {
|
||||
if acc != nil && acc.sl == nil {
|
||||
opts := s.getOpts()
|
||||
if opts != nil && opts.NoSublistCache {
|
||||
acc.sl = NewSublistNoCache()
|
||||
} else {
|
||||
acc.sl = NewSublistWithCache()
|
||||
}
|
||||
acc.sl = NewSublistForServer(s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2293,6 +2263,7 @@ func (s *Server) Start() {
|
||||
s.Noticef(" Node: %s", getHash(s.info.Name))
|
||||
}
|
||||
s.Noticef(" ID: %s", s.info.ID)
|
||||
s.printFeatureFlags(opts)
|
||||
|
||||
defer s.Noticef("Server is ready")
|
||||
|
||||
@@ -3586,7 +3557,7 @@ func (s *Server) saveClosedClient(c *client, nc net.Conn, subs map[string]*subsc
|
||||
if c.acc != nil && c.acc.Name != globalAccountName {
|
||||
cc.acc = c.acc.Name
|
||||
}
|
||||
cc.JWT = c.opts.JWT
|
||||
cc.JWT = redactBearerJWT(c.opts.JWT)
|
||||
cc.IssuerKey = issuerForClient(c)
|
||||
cc.Tags = c.tags
|
||||
cc.NameTag = c.nameTag
|
||||
|
||||
+20
-19
@@ -92,15 +92,15 @@ type ProcessJetStreamMsgHandler func(*inMsg)
|
||||
|
||||
type StreamStore interface {
|
||||
StoreMsg(subject string, hdr, msg []byte, ttl int64) (uint64, int64, error)
|
||||
StoreRawMsg(subject string, hdr, msg []byte, seq uint64, ts int64, ttl int64) error
|
||||
StoreRawMsg(subject string, hdr, msg []byte, seq uint64, ts int64, ttl int64, discardNewCheck bool) error
|
||||
SkipMsg(seq uint64) (uint64, error)
|
||||
SkipMsgs(seq uint64, num uint64) error
|
||||
FlushAllPending()
|
||||
FlushAllPending() error
|
||||
LoadMsg(seq uint64, sm *StoreMsg) (*StoreMsg, error)
|
||||
LoadNextMsg(filter string, wc bool, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error)
|
||||
LoadNextMsgMulti(sl *gsl.SimpleSublist, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error)
|
||||
LoadLastMsg(subject string, sm *StoreMsg) (*StoreMsg, error)
|
||||
LoadPrevMsg(start uint64, smp *StoreMsg) (sm *StoreMsg, err error)
|
||||
LoadPrevMsg(filter string, wc bool, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error)
|
||||
LoadPrevMsgMulti(sl *gsl.SimpleSublist, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error)
|
||||
RemoveMsg(seq uint64) (bool, error)
|
||||
EraseMsg(seq uint64) (bool, error)
|
||||
@@ -109,7 +109,7 @@ type StreamStore interface {
|
||||
Compact(seq uint64) (uint64, error)
|
||||
Truncate(seq uint64) error
|
||||
GetSeqFromTime(t time.Time) uint64
|
||||
FilteredState(seq uint64, subject string) SimpleState
|
||||
FilteredState(seq uint64, subject string) (SimpleState, error)
|
||||
SubjectsState(filterSubject string) map[string]SimpleState
|
||||
SubjectsTotals(filterSubject string) map[string]uint64
|
||||
AllLastSeqs() ([]uint64, error)
|
||||
@@ -120,7 +120,7 @@ type StreamStore interface {
|
||||
State() StreamState
|
||||
FastState(*StreamState)
|
||||
EncodedStreamState(failed uint64) (enc []byte, err error)
|
||||
SyncDeleted(dbs DeleteBlocks)
|
||||
SyncDeleted(dbs DeleteBlocks) error
|
||||
Type() StorageType
|
||||
RegisterStorageUpdates(StorageUpdateHandler)
|
||||
RegisterStorageRemoveMsg(StorageRemoveMsgHandler)
|
||||
@@ -360,11 +360,13 @@ func (dbs DeleteBlocks) NumDeleted() (total uint64) {
|
||||
type ConsumerStore interface {
|
||||
SetStarting(sseq uint64) error
|
||||
UpdateStarting(sseq uint64)
|
||||
Reset(sseq uint64) error
|
||||
HasState() bool
|
||||
UpdateDelivered(dseq, sseq, dc uint64, ts int64) error
|
||||
UpdateAcks(dseq, sseq uint64) error
|
||||
UpdateConfig(cfg *ConsumerConfig) error
|
||||
Update(*ConsumerState) error
|
||||
ForceUpdate(*ConsumerState) error
|
||||
State() (*ConsumerState, error)
|
||||
BorrowState() (*ConsumerState, error)
|
||||
EncodedState() ([]byte, error)
|
||||
@@ -464,13 +466,6 @@ type Pending struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
// TemplateStore stores templates.
|
||||
// Deprecated: stream templates are deprecated and will be removed in a future version.
|
||||
type TemplateStore interface {
|
||||
Store(*streamTemplate) error
|
||||
Delete(*streamTemplate) error
|
||||
}
|
||||
|
||||
const (
|
||||
limitsPolicyJSONString = `"limits"`
|
||||
interestPolicyJSONString = `"interest"`
|
||||
@@ -602,15 +597,17 @@ func (st *StorageType) UnmarshalJSON(data []byte) error {
|
||||
}
|
||||
|
||||
const (
|
||||
ackNonePolicyJSONString = `"none"`
|
||||
ackAllPolicyJSONString = `"all"`
|
||||
ackExplicitPolicyJSONString = `"explicit"`
|
||||
ackNonePolicyJSONString = `"none"`
|
||||
ackAllPolicyJSONString = `"all"`
|
||||
ackExplicitPolicyJSONString = `"explicit"`
|
||||
ackFlowControlPolicyJSONString = `"flow_control"`
|
||||
)
|
||||
|
||||
var (
|
||||
ackNonePolicyJSONBytes = []byte(ackNonePolicyJSONString)
|
||||
ackAllPolicyJSONBytes = []byte(ackAllPolicyJSONString)
|
||||
ackExplicitPolicyJSONBytes = []byte(ackExplicitPolicyJSONString)
|
||||
ackNonePolicyJSONBytes = []byte(ackNonePolicyJSONString)
|
||||
ackAllPolicyJSONBytes = []byte(ackAllPolicyJSONString)
|
||||
ackExplicitPolicyJSONBytes = []byte(ackExplicitPolicyJSONString)
|
||||
ackFlowControlPolicyJSONBytes = []byte(ackFlowControlPolicyJSONString)
|
||||
)
|
||||
|
||||
func (ap AckPolicy) MarshalJSON() ([]byte, error) {
|
||||
@@ -621,6 +618,8 @@ func (ap AckPolicy) MarshalJSON() ([]byte, error) {
|
||||
return ackAllPolicyJSONBytes, nil
|
||||
case AckExplicit:
|
||||
return ackExplicitPolicyJSONBytes, nil
|
||||
case AckFlowControl:
|
||||
return ackFlowControlPolicyJSONBytes, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("can not marshal %v", ap)
|
||||
}
|
||||
@@ -634,6 +633,8 @@ func (ap *AckPolicy) UnmarshalJSON(data []byte) error {
|
||||
*ap = AckAll
|
||||
case ackExplicitPolicyJSONString:
|
||||
*ap = AckExplicit
|
||||
case ackFlowControlPolicyJSONString:
|
||||
*ap = AckFlowControl
|
||||
default:
|
||||
return fmt.Errorf("can not unmarshal %q", data)
|
||||
}
|
||||
@@ -741,7 +742,7 @@ func isOutOfSpaceErr(err error) bool {
|
||||
var errFirstSequenceMismatch = errors.New("first sequence mismatch")
|
||||
|
||||
func isClusterResetErr(err error) bool {
|
||||
return err == errLastSeqMismatch || err == ErrStoreEOF || err == errFirstSequenceMismatch || errors.Is(err, errCatchupAbortedNoLeader) || err == errCatchupTooManyRetries
|
||||
return err == errLastSeqMismatch || err == ErrStoreEOF || err == errFirstSequenceMismatch || errors.Is(err, errCatchupAbortedNoLeader) || err == errCatchupTooManyRetries || err == errAlreadyLeader
|
||||
}
|
||||
|
||||
// Copy all fields.
|
||||
|
||||
+1668
-469
File diff suppressed because it is too large
Load Diff
+12
@@ -121,6 +121,18 @@ func NewSublist(enableCache bool) *Sublist {
|
||||
return &Sublist{root: newLevel()}
|
||||
}
|
||||
|
||||
// NewSublistForServer will create a default sublist with caching enabled determined
|
||||
// by the server options.
|
||||
func NewSublistForServer(srv *Server) *Sublist {
|
||||
if srv == nil {
|
||||
return NewSublistNoCache() // Probably just unit tests.
|
||||
}
|
||||
if opts := srv.getOpts(); opts != nil {
|
||||
return NewSublist(!opts.NoSublistCache)
|
||||
}
|
||||
return NewSublistNoCache()
|
||||
}
|
||||
|
||||
// NewSublistWithCache will create a default sublist with caching enabled.
|
||||
func NewSublistWithCache() *Sublist {
|
||||
return NewSublist(true)
|
||||
|
||||
+5
-2
@@ -155,8 +155,11 @@ func (hw *HashWheel) expireTasks(ts int64, callback func(seq uint64, expires int
|
||||
slotLowest := int64(math.MaxInt64)
|
||||
for seq, expires := range s.entries {
|
||||
if expires <= ts && callback(seq, expires) {
|
||||
delete(s.entries, seq)
|
||||
hw.count--
|
||||
// Only remove if not done so already by the callback.
|
||||
if _, ok := s.entries[seq]; ok {
|
||||
delete(s.entries, seq)
|
||||
hw.count--
|
||||
}
|
||||
continue
|
||||
}
|
||||
if expires < slotLowest {
|
||||
|
||||
+8
@@ -1,3 +1,11 @@
|
||||
## 1.40.0
|
||||
|
||||
We're adopting a new release strategy to minimize dependency bloat in projects that consume Gomega. It is a limitation of the go mod toolchain that _test_ subdependencies of your project's direct dependencies get pulled in as *indirect* dependencies. In the case of Gomega, this ends up pulling in all of Ginkgo into your `go.mod` even if you are only using Gomega (Gomega uses Ginkgo for its own tests).
|
||||
|
||||
Going forward, releases will strip out all tests, tidy up the `go.mod` and then push this stripped down version to a new `master-lite` branch. These stripped-down versions will receive the `vx.y.z` git tag and will be picked up by the go toolchain.
|
||||
|
||||
Please open an issue if this new release process causes unexpected changes for your projects.
|
||||
|
||||
## 1.39.1
|
||||
|
||||
Update all dependencies. This auto-updated the required version of Go to 1.24, consistent with the fact that Go 1.23 has been out of support for almost six months.
|
||||
|
||||
+1
-1
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/onsi/gomega/types"
|
||||
)
|
||||
|
||||
const GOMEGA_VERSION = "1.39.1"
|
||||
const GOMEGA_VERSION = "1.40.0"
|
||||
|
||||
const nilGomegaPanic = `You are trying to make an assertion, but haven't registered Gomega's fail handler.
|
||||
If you're using Ginkgo then you probably forgot to put your assertion in an It().
|
||||
|
||||
Generated
Vendored
-10
@@ -744,16 +744,6 @@ func (s *Service) Delete(ctx context.Context, req *provider.DeleteRequest) (*pro
|
||||
}
|
||||
|
||||
ctx = ctxpkg.ContextSetLockID(ctx, req.LockId)
|
||||
|
||||
// check DeleteRequest for any known opaque properties.
|
||||
// FIXME these should be part of the DeleteRequest object
|
||||
if req.Opaque != nil {
|
||||
if _, ok := req.Opaque.Map["deleting_shared_resource"]; ok {
|
||||
// it is a binary key; its existence signals true. Although, do not assume.
|
||||
ctx = appctx.WithDeletingSharedResource(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
md, err := s.Storage.GetMD(ctx, req.Ref, []string{}, []string{"id", "status"})
|
||||
if err != nil {
|
||||
return &provider.DeleteResponse{
|
||||
|
||||
Generated
Vendored
+10
-4
@@ -182,9 +182,12 @@ func (s *service) GetUser(ctx context.Context, req *userpb.GetUserRequest) (*use
|
||||
user, err := s.usermgr.GetUser(ctx, req.UserId, req.SkipFetchingUserGroups)
|
||||
if err != nil {
|
||||
res := &userpb.GetUserResponse{}
|
||||
if _, ok := err.(errtypes.NotFound); ok {
|
||||
switch err.(type) {
|
||||
case errtypes.NotFound:
|
||||
res.Status = status.NewNotFound(ctx, "user not found")
|
||||
} else {
|
||||
case errtypes.Unavailable:
|
||||
res.Status = status.NewUnavailable(ctx, "user provider temporarily unavailable")
|
||||
default:
|
||||
res.Status = status.NewInternal(ctx, "error getting user")
|
||||
}
|
||||
return res, nil
|
||||
@@ -205,9 +208,12 @@ func (s *service) GetUserByClaim(ctx context.Context, req *userpb.GetUserByClaim
|
||||
user, err := s.usermgr.GetUserByClaim(ctx, req.Claim, req.Value, tenantID, req.SkipFetchingUserGroups)
|
||||
if err != nil {
|
||||
res := &userpb.GetUserByClaimResponse{}
|
||||
if _, ok := err.(errtypes.NotFound); ok {
|
||||
switch err.(type) {
|
||||
case errtypes.NotFound:
|
||||
res.Status = status.NewNotFound(ctx, fmt.Sprintf("user not found %s %s", req.Claim, req.Value))
|
||||
} else {
|
||||
case errtypes.Unavailable:
|
||||
res.Status = status.NewUnavailable(ctx, "user provider temporarily unavailable")
|
||||
default:
|
||||
res.Status = status.NewInternal(ctx, "error getting user by claim")
|
||||
}
|
||||
return res, nil
|
||||
|
||||
vendor/github.com/opencloud-eu/reva/v2/internal/grpc/services/usershareprovider/usershareprovider.go
Generated
Vendored
+4
-4
@@ -84,9 +84,9 @@ type service struct {
|
||||
allowedPathsForShares []*regexp.Regexp
|
||||
}
|
||||
|
||||
func getShareManager(c *config) (share.Manager, error) {
|
||||
func getShareManager(c *config, logger *zerolog.Logger) (share.Manager, error) {
|
||||
if f, ok := registry.NewFuncs[c.Driver]; ok {
|
||||
return f(c.Drivers[c.Driver])
|
||||
return f(c.Drivers[c.Driver], logger)
|
||||
}
|
||||
return nil, errtypes.NotFound("driver not found: " + c.Driver)
|
||||
}
|
||||
@@ -114,7 +114,7 @@ func parseConfig(m map[string]interface{}) (*config, error) {
|
||||
}
|
||||
|
||||
// New creates a new user share provider svc initialized from defaults
|
||||
func NewDefault(m map[string]interface{}, ss *grpc.Server, _ *zerolog.Logger) (rgrpc.Service, error) {
|
||||
func NewDefault(m map[string]any, ss *grpc.Server, logger *zerolog.Logger) (rgrpc.Service, error) {
|
||||
|
||||
c, err := parseConfig(m)
|
||||
if err != nil {
|
||||
@@ -123,7 +123,7 @@ func NewDefault(m map[string]interface{}, ss *grpc.Server, _ *zerolog.Logger) (r
|
||||
|
||||
c.init()
|
||||
|
||||
sm, err := getShareManager(c)
|
||||
sm, err := getShareManager(c, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Generated
Vendored
+7
@@ -155,6 +155,13 @@ func (s *svc) handleProppatch(ctx context.Context, w http.ResponseWriter, r *htt
|
||||
}
|
||||
for j := range patches[i].Props {
|
||||
propNameXML := patches[i].Props[j].XMLName
|
||||
|
||||
// favorites are now managed by the Graph API and can no longer be set using PROPPATCH. To avoid confusion, we return a 403 Forbidden when clients try to set the oc:favorites property
|
||||
if propNameXML.Local == "favorite" {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
// don't use path.Join. It removes the double slash! concatenate with a /
|
||||
key := fmt.Sprintf("%s/%s", patches[i].Props[j].XMLName.Space, patches[i].Props[j].XMLName.Local)
|
||||
value := string(patches[i].Props[j].InnerXML)
|
||||
|
||||
-10
@@ -27,16 +27,6 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// deletingSharedResource flags to a storage a shared resource is being deleted not by the owner.
|
||||
type deletingSharedResource struct{}
|
||||
|
||||
func WithDeletingSharedResource(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, deletingSharedResource{}, struct{}{})
|
||||
}
|
||||
func DeletingSharedResourceFromContext(ctx context.Context) bool {
|
||||
return ctx.Value(deletingSharedResource{}) != nil
|
||||
}
|
||||
|
||||
// WithLogger returns a context with an associated logger.
|
||||
func WithLogger(ctx context.Context, l *zerolog.Logger) context.Context {
|
||||
return l.WithContext(ctx)
|
||||
|
||||
+21
@@ -203,6 +203,15 @@ func (e TooEarly) Error() string { return "error: too early: " + string(e) }
|
||||
// IsTooEarly implements the IsTooEarly interface.
|
||||
func (e TooEarly) IsTooEarly() {}
|
||||
|
||||
// Unavailable is the error to use when a backend service (e.g. LDAP, database) is
|
||||
// temporarily unreachable. Callers should treat this as a transient failure and retry.
|
||||
type Unavailable string
|
||||
|
||||
func (e Unavailable) Error() string { return "error: unavailable: " + string(e) }
|
||||
|
||||
// IsUnavailable implements the IsUnavailable interface.
|
||||
func (e Unavailable) IsUnavailable() {}
|
||||
|
||||
// IsNotFound is the interface to implement
|
||||
// to specify that a resource is not found.
|
||||
type IsNotFound interface {
|
||||
@@ -293,6 +302,12 @@ type IsTooEarly interface {
|
||||
IsTooEarly()
|
||||
}
|
||||
|
||||
// IsUnavailable is the interface to implement to specify that a backend service is
|
||||
// temporarily unavailable and the caller should retry.
|
||||
type IsUnavailable interface {
|
||||
IsUnavailable()
|
||||
}
|
||||
|
||||
// NewErrtypeFromStatus maps a rpc status to an errtype
|
||||
func NewErrtypeFromStatus(status *rpc.Status) error {
|
||||
switch status.Code {
|
||||
@@ -329,6 +344,8 @@ func NewErrtypeFromStatus(status *rpc.Status) error {
|
||||
return BadRequest(status.Message)
|
||||
case rpc.Code_CODE_TOO_EARLY:
|
||||
return TooEarly(status.Message)
|
||||
case rpc.Code_CODE_UNAVAILABLE:
|
||||
return Unavailable(status.Message)
|
||||
default:
|
||||
return InternalError(status.Message)
|
||||
}
|
||||
@@ -363,6 +380,8 @@ func NewErrtypeFromHTTPStatusCode(code int, message string) error {
|
||||
return PartialContent(message)
|
||||
case http.StatusTooEarly:
|
||||
return TooEarly(message)
|
||||
case http.StatusServiceUnavailable:
|
||||
return Unavailable(message)
|
||||
case StatusChecksumMismatch:
|
||||
return ChecksumMismatch(message)
|
||||
default:
|
||||
@@ -399,6 +418,8 @@ func NewHTTPStatusCodeFromErrtype(err error) int {
|
||||
return http.StatusPartialContent
|
||||
case TooEarly:
|
||||
return http.StatusTooEarly
|
||||
case Unavailable:
|
||||
return http.StatusServiceUnavailable
|
||||
case ChecksumMismatch:
|
||||
return StatusChecksumMismatch
|
||||
default:
|
||||
|
||||
+25
-21
@@ -71,9 +71,8 @@ type RawStream struct {
|
||||
c Config
|
||||
}
|
||||
|
||||
func FromConfig(ctx context.Context, name string, cfg Config) (Stream, error) {
|
||||
var s Stream
|
||||
b := backoff.NewExponentialBackOff()
|
||||
func JetStream(ctx context.Context, name string, cfg Config) (jetstream.JetStream, error) {
|
||||
var js jetstream.JetStream
|
||||
|
||||
connect := func() error {
|
||||
var tlsConf *tls.Config
|
||||
@@ -120,27 +119,32 @@ func FromConfig(ctx context.Context, name string, cfg Config) (Stream, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
jsConn, err := jetstream.New(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
js, err := jsConn.Stream(ctx, events.MainQueueName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s = &RawStream{
|
||||
js: js,
|
||||
c: cfg,
|
||||
}
|
||||
return nil
|
||||
js, err = jetstream.New(conn)
|
||||
return err
|
||||
}
|
||||
err := backoff.Retry(connect, b)
|
||||
|
||||
err := backoff.Retry(connect, backoff.NewExponentialBackOff())
|
||||
if err != nil {
|
||||
return s, errors.Wrap(err, "could not connect to nats jetstream")
|
||||
return nil, errors.Wrap(err, "could not connect to nats jetstream")
|
||||
}
|
||||
return s, nil
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func FromConfig(ctx context.Context, name string, cfg Config) (Stream, error) {
|
||||
jsConn, err := JetStream(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
js, err := jsConn.Stream(ctx, events.MainQueueName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RawStream{
|
||||
js: js,
|
||||
c: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *RawStream) Consume(group string, evs ...events.Unmarshaller) (<-chan Event, error) {
|
||||
|
||||
+35
-1
@@ -2,6 +2,7 @@ package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
@@ -11,7 +12,9 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff"
|
||||
"github.com/go-micro/plugins/v4/events/natsjs"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/events"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/events/raw"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -65,7 +68,38 @@ func NatsFromConfig(connName string, disableDurability bool, cfg NatsConfig) (ev
|
||||
opts = append(opts, natsjs.DisableDurableStreams())
|
||||
}
|
||||
|
||||
return Nats(opts...)
|
||||
s, err := Nats(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// apply a MaxAge to the main queue to prevent it from filling up
|
||||
ctx := context.Background()
|
||||
jsConn, err := raw.JetStream(ctx, connName, raw.Config{
|
||||
Endpoint: cfg.Endpoint,
|
||||
Cluster: cfg.Cluster,
|
||||
TLSInsecure: cfg.TLSInsecure,
|
||||
TLSRootCACertificate: cfg.TLSRootCACertificate,
|
||||
EnableTLS: cfg.EnableTLS,
|
||||
AuthUsername: cfg.AuthUsername,
|
||||
AuthPassword: cfg.AuthPassword,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
streamCfg := jetstream.StreamConfig{
|
||||
Name: "main-queue",
|
||||
MaxAge: 7 * 24 * time.Hour,
|
||||
}
|
||||
_, err = jsConn.CreateStream(ctx, streamCfg)
|
||||
if err != nil {
|
||||
// If the stream already exists, update its configuration
|
||||
if err == jetstream.ErrStreamNameAlreadyInUse {
|
||||
_, _ = jsConn.UpdateStream(ctx, streamCfg)
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// nats returns a nats streaming client
|
||||
|
||||
+13
@@ -68,6 +68,15 @@ func NewInternal(ctx context.Context, msg string) *rpc.Status {
|
||||
}
|
||||
}
|
||||
|
||||
// NewUnavailable returns a Status with CODE_UNAVAILABLE.
|
||||
func NewUnavailable(ctx context.Context, msg string) *rpc.Status {
|
||||
return &rpc.Status{
|
||||
Code: rpc.Code_CODE_UNAVAILABLE,
|
||||
Message: msg,
|
||||
Trace: getTrace(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
// NewUnauthenticated returns a Status with CODE_UNAUTHENTICATED.
|
||||
func NewUnauthenticated(ctx context.Context, err error, msg string) *rpc.Status {
|
||||
return &rpc.Status{
|
||||
@@ -191,6 +200,10 @@ func NewStatusFromErrType(ctx context.Context, msg string, err error) *rpc.Statu
|
||||
return NewUnimplemented(ctx, err, msg+":"+err.Error())
|
||||
case errtypes.BadRequest:
|
||||
return NewInvalid(ctx, msg+":"+err.Error())
|
||||
case errtypes.Unavailable:
|
||||
return NewUnavailable(ctx, msg+": "+err.Error())
|
||||
case errtypes.IsUnavailable:
|
||||
return NewUnavailable(ctx, msg+": "+err.Error())
|
||||
}
|
||||
|
||||
// map GRPC status codes coming from the auth middleware
|
||||
|
||||
+170
-60
@@ -36,9 +36,9 @@ import (
|
||||
"github.com/opencloud-eu/reva/v2/pkg/errtypes"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/events"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/events/stream"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/logger"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/rgrpc/todo/pool"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share"
|
||||
migration "github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/migrations"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/providercache"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/receivedsharecache"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/sharecache"
|
||||
@@ -48,6 +48,7 @@ import (
|
||||
"github.com/opencloud-eu/reva/v2/pkg/storagespace"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/utils"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/genproto/protobuf/field_mask"
|
||||
@@ -122,14 +123,20 @@ var (
|
||||
)
|
||||
|
||||
type config struct {
|
||||
GatewayAddr string `mapstructure:"gateway_addr"`
|
||||
MaxConcurrency int `mapstructure:"max_concurrency"`
|
||||
ProviderAddr string `mapstructure:"provider_addr"`
|
||||
ServiceUserID string `mapstructure:"service_user_id"`
|
||||
ServiceUserIdp string `mapstructure:"service_user_idp"`
|
||||
MachineAuthAPIKey string `mapstructure:"machine_auth_apikey"`
|
||||
CacheTTL int `mapstructure:"ttl"`
|
||||
Events EventOptions `mapstructure:"events"`
|
||||
GatewayAddr string `mapstructure:"gateway_addr"`
|
||||
MaxConcurrency int `mapstructure:"max_concurrency"`
|
||||
ProviderAddr string `mapstructure:"provider_addr"`
|
||||
SystemUserID string `mapstructure:"system_user_id"`
|
||||
SystemUserIdp string `mapstructure:"system_user_idp"`
|
||||
MachineAuthAPIKey string `mapstructure:"machine_auth_apikey"`
|
||||
ServiceAccountID string `mapstructure:"service_account_id"`
|
||||
ServiceAccountSecret string `mapstructure:"service_account_secret"`
|
||||
// ProviderRegistryAddr is the address of the storage registry used during
|
||||
// migrations. Defaults to GatewayAddr when empty, because in the default
|
||||
// OpenCloud deployment the registry is co-located with the gateway.
|
||||
ProviderRegistryAddr string `mapstructure:"provider_registry_addr"`
|
||||
CacheTTL int `mapstructure:"ttl"`
|
||||
Events EventOptions `mapstructure:"events"`
|
||||
}
|
||||
|
||||
// EventOptions are the configurable options for events
|
||||
@@ -145,8 +152,6 @@ type EventOptions struct {
|
||||
|
||||
// Manager implements a share manager using a cs3 storage backend with local caching
|
||||
type Manager struct {
|
||||
sync.RWMutex
|
||||
|
||||
Cache providercache.Cache // holds all shares, sharded by provider id and space id
|
||||
CreatedCache sharecache.Cache // holds the list of shares a user has created, sharded by user id
|
||||
GroupReceivedCache sharecache.Cache // holds the list of shares a group has access to, sharded by group id
|
||||
@@ -155,23 +160,25 @@ type Manager struct {
|
||||
storage metadata.Storage
|
||||
SpaceRoot *provider.ResourceId
|
||||
|
||||
initialized bool
|
||||
ready chan struct{} // closed once initialize() has completed successfully
|
||||
migrationsDone chan struct{} // closed once doMigrations() has returned on this instance
|
||||
|
||||
MaxConcurrency int
|
||||
|
||||
gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient]
|
||||
eventStream events.Stream
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
// NewDefault returns a new manager instance with default dependencies
|
||||
func NewDefault(m map[string]interface{}) (share.Manager, error) {
|
||||
func NewDefault(m map[string]interface{}, logger *zerolog.Logger) (share.Manager, error) {
|
||||
c := &config{}
|
||||
if err := mapstructure.Decode(m, c); err != nil {
|
||||
err = errors.Wrap(err, "error creating a new manager")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, err := metadata.NewCS3Storage(c.ProviderAddr, c.ProviderAddr, c.ServiceUserID, c.ServiceUserIdp, c.MachineAuthAPIKey)
|
||||
s, err := metadata.NewCS3Storage(c.ProviderAddr, c.ProviderAddr, c.SystemUserID, c.SystemUserIdp, c.MachineAuthAPIKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,11 +196,34 @@ func NewDefault(m map[string]interface{}) (share.Manager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return New(s, gatewaySelector, c.CacheTTL, es, c.MaxConcurrency)
|
||||
mgr, err := New(s, logger, gatewaySelector, c.CacheTTL, es, c.MaxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerRegistryAddr := c.ProviderRegistryAddr
|
||||
if providerRegistryAddr == "" {
|
||||
providerRegistryAddr = c.GatewayAddr
|
||||
}
|
||||
mgr.RunMigrations(migration.MigrationConfig{
|
||||
ServiceAccountID: c.ServiceAccountID,
|
||||
ServiceAccountSecret: c.ServiceAccountSecret,
|
||||
ProviderRegistryAddr: providerRegistryAddr,
|
||||
})
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// New returns a new manager instance.
|
||||
func New(s metadata.Storage, gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient], ttlSeconds int, es events.Stream, maxconcurrency int) (*Manager, error) {
|
||||
func New(s metadata.Storage,
|
||||
logger *zerolog.Logger,
|
||||
gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient],
|
||||
ttlSeconds int,
|
||||
es events.Stream,
|
||||
maxconcurrency int,
|
||||
) (*Manager, error) {
|
||||
if logger == nil {
|
||||
nop := zerolog.Nop()
|
||||
logger = &nop
|
||||
}
|
||||
ttl := time.Duration(ttlSeconds) * time.Second
|
||||
|
||||
m := &Manager{
|
||||
@@ -205,13 +235,38 @@ func New(s metadata.Storage, gatewaySelector pool.Selectable[gatewayv1beta1.Gate
|
||||
gatewaySelector: gatewaySelector,
|
||||
eventStream: es,
|
||||
MaxConcurrency: maxconcurrency,
|
||||
logger: logger,
|
||||
ready: make(chan struct{}),
|
||||
// migrationsDone is open (blocking) by default. It is closed by
|
||||
// doMigrations when all migrations complete, or by SkipMigrations for
|
||||
// callers (e.g. tests) that do not run migrations at all.
|
||||
migrationsDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize the metadata storage connection in the background, retrying
|
||||
// with exponential backoff if the backend is not yet available.
|
||||
go func() {
|
||||
backoff := time.Second
|
||||
for {
|
||||
if err := m.initialize(context.Background()); err != nil {
|
||||
logger.Info().Err(err).Dur("backoff", backoff).Msg("share manager: metadata storage initialization failed, retrying")
|
||||
time.Sleep(backoff)
|
||||
if backoff < 30*time.Second {
|
||||
backoff *= 2
|
||||
}
|
||||
continue
|
||||
}
|
||||
logger.Debug().Msg("share manager: initialization succeeded")
|
||||
close(m.ready)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// listen for events
|
||||
if m.eventStream != nil {
|
||||
ch, err := events.Consume(m.eventStream, "jsoncs3sharemanager", _registeredEvents...)
|
||||
if err != nil {
|
||||
appctx.GetLogger(context.Background()).Error().Err(err).Msg("error consuming events")
|
||||
logger.Error().Err(err).Msg("error consuming events")
|
||||
}
|
||||
go m.ProcessEvents(ch)
|
||||
}
|
||||
@@ -219,23 +274,13 @@ func New(s metadata.Storage, gatewaySelector pool.Selectable[gatewayv1beta1.Gate
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// initialize connects to the metadata storage backend and ensures the required
|
||||
// directory structure exists. It is called once at startup from a background
|
||||
// goroutine (see New) and must not be called concurrently.
|
||||
func (m *Manager) initialize(ctx context.Context) error {
|
||||
_, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "initialize")
|
||||
defer span.End()
|
||||
if m.initialized {
|
||||
span.SetStatus(codes.Ok, "already initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if m.initialized { // check if initialization happened while grabbing the lock
|
||||
span.SetStatus(codes.Ok, "initialized while grabbing lock")
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx = context.Background()
|
||||
err := m.storage.Init(ctx, "jsoncs3-share-manager-metadata")
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
@@ -261,21 +306,85 @@ func (m *Manager) initialize(ctx context.Context) error {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return err
|
||||
}
|
||||
err = m.storage.MakeDirIfNotExist(ctx, "migrations")
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
m.initialized = true
|
||||
span.SetStatus(codes.Ok, "initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForInit blocks until the background initialization goroutine has
|
||||
// successfully completed, or until ctx is cancelled.
|
||||
func (m *Manager) waitForInit(ctx context.Context) error {
|
||||
select {
|
||||
case <-m.ready:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "share manager not yet initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// waitForMigrations blocks until both storage initialization and all data
|
||||
// migrations have completed on this instance, or until ctx is cancelled.
|
||||
// It is a strict superset of waitForInit and should be used by write operations
|
||||
// to ensure no writes race with an in-progress migration.
|
||||
func (m *Manager) waitForMigrations(ctx context.Context) error {
|
||||
select {
|
||||
case <-m.ready:
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "share manager not yet initialized")
|
||||
}
|
||||
select {
|
||||
case <-m.migrationsDone:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "share manager migrations not yet complete")
|
||||
}
|
||||
}
|
||||
|
||||
// RunMigrations starts data migrations in a background goroutine. It should be
|
||||
// called once after New() in production server startup. Callers that do not
|
||||
// need migrations should call SkipMigrations instead to unblock write operations.
|
||||
func (m *Manager) RunMigrations(cfg migration.MigrationConfig) {
|
||||
go m.doMigrations(cfg)
|
||||
}
|
||||
|
||||
// SkipMigrations unblocks write operations on this instance without running
|
||||
// any migrations. It must be called when RunMigrations will not be called,
|
||||
// for example in tests.
|
||||
func (m *Manager) SkipMigrations() {
|
||||
close(m.migrationsDone)
|
||||
}
|
||||
|
||||
func (m *Manager) doMigrations(cfg migration.MigrationConfig) {
|
||||
// Always close migrationsDone when this goroutine exits, whether migrations
|
||||
// ran, were skipped, or failed. This unblocks write operations on this
|
||||
// instance. Non-winning instances are held here by acquireLock until the
|
||||
// winning instance finishes, so the close happens only after the storage
|
||||
// state is fully migrated.
|
||||
defer close(m.migrationsDone)
|
||||
if err := m.waitForInit(context.Background()); err != nil {
|
||||
m.logger.Error().Err(err).Msg("share manager: aborting migrations, manager did not initialize")
|
||||
return
|
||||
}
|
||||
m.logger.Debug().Msg("migrations start")
|
||||
migrations := migration.New(*m.logger, m.gatewaySelector, m.storage, cfg, m, m)
|
||||
migrations.RunMigrations()
|
||||
}
|
||||
|
||||
func (m *Manager) ProcessEvents(ch <-chan events.Event) {
|
||||
log := logger.New()
|
||||
log := m.logger
|
||||
ctx := context.Background()
|
||||
if err := m.waitForInit(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("share manager: error waiting for initialization")
|
||||
return
|
||||
}
|
||||
for event := range ch {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("error initializing manager")
|
||||
}
|
||||
|
||||
if ev, ok := event.Event.(events.SpaceDeleted); ok {
|
||||
log.Debug().Msgf("space deleted event: %v", ev)
|
||||
go func() { m.purgeSpace(ctx, ev.ID) }()
|
||||
@@ -287,7 +396,7 @@ func (m *Manager) ProcessEvents(ch <-chan events.Event) {
|
||||
func (m *Manager) Share(ctx context.Context, md *provider.ResourceInfo, g *collaboration.ShareGrant) (*collaboration.Share, error) {
|
||||
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "Share")
|
||||
defer span.End()
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForMigrations(ctx); err != nil {
|
||||
span.RecordError(err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
@@ -436,7 +545,7 @@ func (m *Manager) GetShare(ctx context.Context, ref *collaboration.ShareReferenc
|
||||
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "GetShare")
|
||||
defer span.End()
|
||||
sublog := appctx.GetLogger(ctx).With().Str("id", ref.GetId().GetOpaqueId()).Str("key", ref.GetKey().String()).Str("driver", "jsoncs3").Str("handler", "GetShare").Logger()
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForInit(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -494,7 +603,7 @@ func (m *Manager) Unshare(ctx context.Context, ref *collaboration.ShareReference
|
||||
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "Unshare")
|
||||
defer span.End()
|
||||
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForMigrations(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -511,7 +620,7 @@ func (m *Manager) UpdateShare(ctx context.Context, ref *collaboration.ShareRefer
|
||||
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "UpdateShare")
|
||||
defer span.End()
|
||||
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForMigrations(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -599,7 +708,7 @@ func (m *Manager) ListShares(ctx context.Context, filters []*collaboration.Filte
|
||||
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "ListShares")
|
||||
defer span.End()
|
||||
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForInit(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -816,7 +925,7 @@ func (m *Manager) ListReceivedShares(ctx context.Context, filters []*collaborati
|
||||
defer span.End()
|
||||
sublog := appctx.GetLogger(ctx).With().Str("driver", "jsoncs3").Str("handler", "ListReceivedShares").Logger()
|
||||
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForInit(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1012,7 +1121,7 @@ func (m *Manager) convert(ctx context.Context, userID string, s *collaboration.S
|
||||
|
||||
// GetReceivedShare returns the information for a received share.
|
||||
func (m *Manager) GetReceivedShare(ctx context.Context, ref *collaboration.ShareReference) (*collaboration.ReceivedShare, error) {
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForInit(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1056,7 +1165,7 @@ func (m *Manager) UpdateReceivedShare(ctx context.Context, receivedShare *collab
|
||||
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "UpdateReceivedShare")
|
||||
defer span.End()
|
||||
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForMigrations(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1103,8 +1212,8 @@ func updateShareID(share *collaboration.Share) {
|
||||
|
||||
// Load imports shares and received shares from channels (e.g. during migration)
|
||||
func (m *Manager) Load(ctx context.Context, shareChan <-chan *collaboration.Share, receivedShareChan <-chan share.ReceivedShareWithUser) error {
|
||||
log := appctx.GetLogger(ctx)
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
l := m.logger
|
||||
if err := m.waitForInit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1119,14 +1228,14 @@ func (m *Manager) Load(ctx context.Context, shareChan <-chan *collaboration.Shar
|
||||
updateShareID(s)
|
||||
}
|
||||
if err := m.Cache.Add(context.Background(), s.GetResourceId().GetStorageId(), s.GetResourceId().GetSpaceId(), s.Id.OpaqueId, s); err != nil {
|
||||
log.Error().Err(err).Interface("share", s).Msg("error persisting share")
|
||||
l.Error().Err(err).Interface("share", s).Msg("error persisting share")
|
||||
} else {
|
||||
log.Debug().Str("storageid", s.GetResourceId().GetStorageId()).Str("spaceid", s.GetResourceId().GetSpaceId()).Str("shareid", s.Id.OpaqueId).Msg("imported share")
|
||||
l.Debug().Str("storageid", s.GetResourceId().GetStorageId()).Str("spaceid", s.GetResourceId().GetSpaceId()).Str("shareid", s.Id.OpaqueId).Msg("imported share")
|
||||
}
|
||||
if err := m.CreatedCache.Add(ctx, s.GetCreator().GetOpaqueId(), s.Id.OpaqueId); err != nil {
|
||||
log.Error().Err(err).Interface("share", s).Msg("error persisting created cache")
|
||||
l.Error().Err(err).Interface("share", s).Msg("error persisting created cache")
|
||||
} else {
|
||||
log.Debug().Str("creatorid", s.GetCreator().GetOpaqueId()).Str("shareid", s.Id.OpaqueId).Msg("updated created cache")
|
||||
l.Debug().Str("creatorid", s.GetCreator().GetOpaqueId()).Str("shareid", s.Id.OpaqueId).Msg("updated created cache")
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
@@ -1137,18 +1246,19 @@ func (m *Manager) Load(ctx context.Context, shareChan <-chan *collaboration.Shar
|
||||
if !shareIsRoutable(s.ReceivedShare.GetShare()) {
|
||||
updateShareID(s.ReceivedShare.GetShare())
|
||||
}
|
||||
switch s.ReceivedShare.Share.Grantee.Type {
|
||||
case provider.GranteeType_GRANTEE_TYPE_USER:
|
||||
if err := m.UserReceivedStates.Add(context.Background(), s.ReceivedShare.GetShare().GetGrantee().GetUserId().GetOpaqueId(), s.ReceivedShare.GetShare().GetResourceId().GetSpaceId(), s.ReceivedShare); err != nil {
|
||||
log.Error().Err(err).Interface("received share", s).Msg("error persisting received share for user")
|
||||
if s.UserID != nil {
|
||||
spaceid := s.ReceivedShare.GetShare().GetResourceId().GetStorageId() + shareid.IDDelimiter + s.ReceivedShare.GetShare().GetResourceId().GetSpaceId()
|
||||
if err := m.UserReceivedStates.Add(context.Background(), s.UserID.GetOpaqueId(), spaceid, s.ReceivedShare); err != nil {
|
||||
l.Error().Err(err).Interface("received share", s).Msg("error persisting received share for user")
|
||||
} else {
|
||||
log.Debug().Str("userid", s.ReceivedShare.GetShare().GetGrantee().GetUserId().GetOpaqueId()).Str("spaceid", s.ReceivedShare.GetShare().GetResourceId().GetSpaceId()).Str("shareid", s.ReceivedShare.GetShare().Id.OpaqueId).Msg("updated received share userdata")
|
||||
l.Debug().Str("userid", s.UserID.GetOpaqueId()).Str("spaceid", spaceid).Str("shareid", s.ReceivedShare.GetShare().Id.OpaqueId).Msg("updated received share userdata")
|
||||
}
|
||||
case provider.GranteeType_GRANTEE_TYPE_GROUP:
|
||||
}
|
||||
if s.ReceivedShare.Share.Grantee.Type == provider.GranteeType_GRANTEE_TYPE_GROUP && s.UserID == nil {
|
||||
if err := m.GroupReceivedCache.Add(context.Background(), s.ReceivedShare.GetShare().GetGrantee().GetGroupId().GetOpaqueId(), s.ReceivedShare.GetShare().GetId().GetOpaqueId()); err != nil {
|
||||
log.Error().Err(err).Interface("received share", s).Msg("error persisting received share to group cache")
|
||||
l.Error().Err(err).Interface("received share", s).Msg("error persisting received share to group cache")
|
||||
} else {
|
||||
log.Debug().Str("groupid", s.ReceivedShare.GetShare().GetGrantee().GetGroupId().GetOpaqueId()).Str("shareid", s.ReceivedShare.GetShare().Id.OpaqueId).Msg("updated received share group cache")
|
||||
l.Debug().Str("groupid", s.ReceivedShare.GetShare().GetGrantee().GetGroupId().GetOpaqueId()).Str("shareid", s.ReceivedShare.GetShare().Id.OpaqueId).Msg("updated received share group cache")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1220,7 +1330,7 @@ func (m *Manager) removeShare(ctx context.Context, s *collaboration.Share, skipS
|
||||
func (m *Manager) CleanupStaleShares(ctx context.Context) {
|
||||
log := appctx.GetLogger(ctx)
|
||||
|
||||
if err := m.initialize(ctx); err != nil {
|
||||
if err := m.waitForMigrations(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Generated
Vendored
+435
@@ -0,0 +1,435 @@
|
||||
// Copyright 2026 OpenCloud GmbH
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// In applying this license, CERN does not waive the privileges and immunities
|
||||
// granted to it by virtue of its status as an Intergovernmental Organization
|
||||
// or submit itself to any jurisdiction.
|
||||
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff"
|
||||
grouppb "github.com/cs3org/go-cs3apis/cs3/identity/group/v1beta1"
|
||||
userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
|
||||
rpc "github.com/cs3org/go-cs3apis/cs3/rpc/v1beta1"
|
||||
collaboration "github.com/cs3org/go-cs3apis/cs3/sharing/collaboration/v1beta1"
|
||||
provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1"
|
||||
registry "github.com/cs3org/go-cs3apis/cs3/storage/registry/v1beta1"
|
||||
typesv1beta1 "github.com/cs3org/go-cs3apis/cs3/types/v1beta1"
|
||||
"github.com/google/uuid"
|
||||
ctxpkg "github.com/opencloud-eu/reva/v2/pkg/ctx"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/errtypes"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/rgrpc/todo/pool"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/shareid"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/utils"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// storageProvider is the narrow subset of provider.ProviderAPIClient that the
|
||||
// migration actually uses. Keeping it narrow makes test stubs trivial to write.
|
||||
type storageProvider interface {
|
||||
ListGrants(ctx context.Context, in *provider.ListGrantsRequest, opts ...grpc.CallOption) (*provider.ListGrantsResponse, error)
|
||||
}
|
||||
|
||||
type ImportSpaceMembersMigration struct {
|
||||
cfg config
|
||||
sharesChan chan *collaboration.Share
|
||||
receivedChan chan share.ReceivedShareWithUser
|
||||
userCache map[string]*userpb.UserId
|
||||
groupCache map[string]*grouppb.GroupId
|
||||
providerResolver func(context.Context, *provider.StorageSpace) (storageProvider, error)
|
||||
}
|
||||
|
||||
func init() {
|
||||
registerMigration(&ImportSpaceMembersMigration{})
|
||||
}
|
||||
|
||||
func (m *ImportSpaceMembersMigration) Initialize(cfg config) {
|
||||
m.cfg = cfg
|
||||
m.sharesChan = make(chan *collaboration.Share)
|
||||
m.receivedChan = make(chan share.ReceivedShareWithUser)
|
||||
m.userCache = make(map[string]*userpb.UserId)
|
||||
m.groupCache = make(map[string]*grouppb.GroupId)
|
||||
m.providerResolver = func(ctx context.Context, space *provider.StorageSpace) (storageProvider, error) {
|
||||
return m.storageProviderForSpace(ctx, space)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ImportSpaceMembersMigration) Name() string {
|
||||
return "import_space_members"
|
||||
}
|
||||
|
||||
func (m *ImportSpaceMembersMigration) Version() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (m *ImportSpaceMembersMigration) Migrate() error {
|
||||
gwc, err := m.cfg.gatewaySelector.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svcCtx, err := utils.GetServiceUserContextWithContext(context.Background(), gwc, m.cfg.serviceAccountID, m.cfg.serviceAccountSecret)
|
||||
if err != nil {
|
||||
m.cfg.logger.Error().Err(err).Msg("failed to get service user context for migration")
|
||||
return err
|
||||
}
|
||||
// List all project spaces.
|
||||
listRes, err := gwc.ListStorageSpaces(svcCtx, &provider.ListStorageSpacesRequest{
|
||||
Opaque: utils.AppendPlainToOpaque(nil, "unrestricted", "true"),
|
||||
Filters: []*provider.ListStorageSpacesRequest_Filter{
|
||||
{
|
||||
Type: provider.ListStorageSpacesRequest_Filter_TYPE_SPACE_TYPE,
|
||||
Term: &provider.ListStorageSpacesRequest_Filter_SpaceType{SpaceType: "project"},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
m.cfg.logger.Error().Err(err).Msg("space-membership migration: failed to list storage spaces")
|
||||
return err
|
||||
}
|
||||
|
||||
if listRes.GetStatus().GetCode() != rpc.Code_CODE_OK {
|
||||
m.cfg.logger.Error().Str("status", listRes.GetStatus().GetMessage()).Msg("space-membership migration: ListStorageSpaces returned non-OK status")
|
||||
return errtypes.InternalError("ListStorageSpaces")
|
||||
}
|
||||
|
||||
spaces := listRes.GetStorageSpaces()
|
||||
m.cfg.logger.Info().Int("spaces", len(spaces)).Msg("Starting migration")
|
||||
|
||||
// loadCtx is cancelled when the producer finishes (or fails) so that the
|
||||
// Load goroutine — which blocks reading from the channels — is not left
|
||||
// waiting forever if we return early from an error.
|
||||
loadCtx, cancelLoad := context.WithCancel(svcCtx)
|
||||
defer cancelLoad()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var loaderError error
|
||||
wg.Go(func() {
|
||||
loaderError = m.cfg.loader.Load(loadCtx, m.sharesChan, m.receivedChan)
|
||||
})
|
||||
|
||||
migrated := 0
|
||||
for _, space := range spaces {
|
||||
sharesCreated, err := m.migrateSpace(loadCtx, space)
|
||||
if err != nil {
|
||||
m.cfg.logger.Error().Err(err).Str("space", space.GetId().GetOpaqueId()).Msg("failed to migrate space; continuing with remaining spaces")
|
||||
continue
|
||||
}
|
||||
migrated++
|
||||
m.cfg.logger.Debug().
|
||||
Str("space", space.GetId().GetOpaqueId()).
|
||||
Int("shares_created", sharesCreated).
|
||||
Msg("space migrated")
|
||||
if migrated%10 == 0 {
|
||||
m.cfg.logger.Info().
|
||||
Int("migrated", migrated).
|
||||
Int("total", len(spaces)).
|
||||
Msg("migration progress")
|
||||
}
|
||||
}
|
||||
close(m.receivedChan)
|
||||
close(m.sharesChan)
|
||||
|
||||
wg.Wait()
|
||||
m.cfg.logger.Info().Err(loaderError).Int("migrated", migrated).Int("total", len(spaces)).Msg("Migration finished")
|
||||
return loaderError
|
||||
}
|
||||
|
||||
func (m *ImportSpaceMembersMigration) migrateSpace(ctx context.Context, space *provider.StorageSpace) (int, error) {
|
||||
spClient, err := m.providerResolver(ctx, space)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ref := &provider.Reference{ResourceId: space.GetRoot()}
|
||||
grantsRes, err := spClient.ListGrants(ctx, &provider.ListGrantsRequest{Ref: ref})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if grantsRes.GetStatus().GetCode() != rpc.Code_CODE_OK {
|
||||
return 0, errtypes.NewErrtypeFromStatus(grantsRes.GetStatus())
|
||||
}
|
||||
|
||||
sharesCreated := 0
|
||||
for _, grant := range grantsRes.GetGrants() {
|
||||
share, receivedShares, err := m.spaceGrantToShares(ctx, grant, space)
|
||||
if err != nil {
|
||||
m.cfg.logger.Error().Err(err).
|
||||
Interface("grant", grant).
|
||||
Msg("Failed to convert grant to shares")
|
||||
continue
|
||||
}
|
||||
if share == nil {
|
||||
// share already existed; nothing to import for this grant
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case m.sharesChan <- share:
|
||||
case <-ctx.Done():
|
||||
return sharesCreated, ctx.Err()
|
||||
}
|
||||
for _, rs := range receivedShares {
|
||||
select {
|
||||
case m.receivedChan <- rs:
|
||||
case <-ctx.Done():
|
||||
return sharesCreated, ctx.Err()
|
||||
}
|
||||
}
|
||||
sharesCreated++
|
||||
}
|
||||
return sharesCreated, nil
|
||||
}
|
||||
|
||||
// resolveRetries is the maximum number of times resolveUserID / resolveGroupID
|
||||
// will retry after receiving an errtypes.Unavailable response (LDAP down).
|
||||
const resolveRetries = 10
|
||||
|
||||
// retryOnUnavailable calls op, retrying with exponential backoff whenever op
|
||||
// returns errtypes.Unavailable. Any other error (including context
|
||||
// cancellation) stops the loop immediately and is returned as-is.
|
||||
// Retries are capped at resolveRetries attempts and respect ctx cancellation.
|
||||
func retryOnUnavailable(ctx context.Context, log zerolog.Logger, op func() error) error {
|
||||
b := backoff.WithContext(
|
||||
backoff.WithMaxRetries(backoff.NewExponentialBackOff(), resolveRetries),
|
||||
ctx,
|
||||
)
|
||||
notify := func(err error, d time.Duration) {
|
||||
log.Warn().Err(err).Dur("retry_in", d).Msg("identity provider temporarily unavailable, retrying")
|
||||
}
|
||||
return backoff.RetryNotify(func() error {
|
||||
err := op()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if _, ok := err.(errtypes.Unavailable); ok {
|
||||
return err // transient — keep retrying
|
||||
}
|
||||
return backoff.Permanent(err) // permanent — stop immediately
|
||||
}, b, notify)
|
||||
}
|
||||
|
||||
func (m *ImportSpaceMembersMigration) resolveUserID(ctx context.Context, opaqueID string) (*userpb.UserId, error) {
|
||||
if id, ok := m.userCache[opaqueID]; ok {
|
||||
return id, nil
|
||||
}
|
||||
var id *userpb.UserId
|
||||
err := retryOnUnavailable(ctx, m.cfg.logger, func() error {
|
||||
gwc, err := m.cfg.gatewaySelector.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := gwc.GetUser(ctx, &userpb.GetUserRequest{
|
||||
UserId: &userpb.UserId{OpaqueId: opaqueID},
|
||||
SkipFetchingUserGroups: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res.GetStatus().GetCode() != rpc.Code_CODE_OK {
|
||||
// errtypes.NewErrtypeFromStatus maps CODE_UNAVAILABLE → errtypes.Unavailable,
|
||||
// which retryOnUnavailable will retry; all other codes are treated as permanent.
|
||||
return errtypes.NewErrtypeFromStatus(res.GetStatus())
|
||||
}
|
||||
id = res.GetUser().GetId()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.userCache[opaqueID] = id
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (m *ImportSpaceMembersMigration) resolveGroupID(ctx context.Context, opaqueID string) (*grouppb.GroupId, error) {
|
||||
if id, ok := m.groupCache[opaqueID]; ok {
|
||||
return id, nil
|
||||
}
|
||||
var id *grouppb.GroupId
|
||||
err := retryOnUnavailable(ctx, m.cfg.logger, func() error {
|
||||
gwc, err := m.cfg.gatewaySelector.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := gwc.GetGroup(ctx, &grouppb.GetGroupRequest{
|
||||
GroupId: &grouppb.GroupId{OpaqueId: opaqueID},
|
||||
SkipFetchingMembers: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res.GetStatus().GetCode() != rpc.Code_CODE_OK {
|
||||
return errtypes.NewErrtypeFromStatus(res.GetStatus())
|
||||
}
|
||||
id = res.GetGroup().GetId()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.groupCache[opaqueID] = id
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (m *ImportSpaceMembersMigration) spaceGrantToShares(ctx context.Context, grant *provider.Grant, space *provider.StorageSpace) (*collaboration.Share, []share.ReceivedShareWithUser, error) {
|
||||
// The grantee ids as persisted on disk do not have an IDP or type stored as
|
||||
// part of the userid/groupid. Resolve them via the gateway so we get the
|
||||
// full userid
|
||||
switch grant.GetGrantee().GetType() {
|
||||
case provider.GranteeType_GRANTEE_TYPE_GROUP:
|
||||
groupID, err := m.resolveGroupID(ctx, grant.GetGrantee().GetGroupId().GetOpaqueId())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("resolve group %s: %w", grant.GetGrantee().GetGroupId().GetOpaqueId(), err)
|
||||
}
|
||||
grant.Grantee.Id = &provider.Grantee_GroupId{GroupId: groupID}
|
||||
case provider.GranteeType_GRANTEE_TYPE_USER:
|
||||
userID, err := m.resolveUserID(ctx, grant.GetGrantee().GetUserId().GetOpaqueId())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("resolve user %s: %w", grant.GetGrantee().GetUserId().GetOpaqueId(), err)
|
||||
}
|
||||
grant.Grantee.Id = &provider.Grantee_UserId{UserId: userID}
|
||||
}
|
||||
|
||||
ref := &collaboration.ShareReference{
|
||||
Spec: &collaboration.ShareReference_Key{
|
||||
Key: &collaboration.ShareKey{
|
||||
ResourceId: space.GetRoot(),
|
||||
Grantee: grant.GetGrantee(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx = ctxpkg.ContextSetUser(ctx, &userpb.User{Id: grant.Creator})
|
||||
if s, err := m.cfg.manager.GetShare(ctx, ref); err == nil {
|
||||
// FIXME: Verify the actual grants?
|
||||
m.cfg.logger.Debug().Interface("share", s).Msg("share already exists")
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
ts := utils.TSNow()
|
||||
shareID := shareid.Encode(space.GetRoot().GetStorageId(), space.GetRoot().GetSpaceId(), uuid.NewString())
|
||||
|
||||
creator := grant.GetCreator()
|
||||
if creator.Type == userpb.UserType_USER_TYPE_INVALID {
|
||||
creator = nil
|
||||
}
|
||||
newShare := &collaboration.Share{
|
||||
Id: &collaboration.ShareId{OpaqueId: shareID},
|
||||
ResourceId: space.GetRoot(),
|
||||
Permissions: &collaboration.SharePermissions{Permissions: grant.GetPermissions()},
|
||||
Grantee: grant.GetGrantee(),
|
||||
Expiration: grant.GetExpiration(),
|
||||
Owner: creator,
|
||||
Creator: creator,
|
||||
Ctime: ts,
|
||||
Mtime: ts,
|
||||
}
|
||||
|
||||
var newReceivedShares []share.ReceivedShareWithUser
|
||||
switch grant.GetGrantee().GetType() {
|
||||
case provider.GranteeType_GRANTEE_TYPE_GROUP:
|
||||
gwc, err := m.cfg.gatewaySelector.Next()
|
||||
if err != nil {
|
||||
m.cfg.logger.Error().Err(err).Msg("Failed to get gateway client")
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
gr, err := gwc.GetMembers(ctx, &grouppb.GetMembersRequest{
|
||||
GroupId: grant.GetGrantee().GetGroupId(),
|
||||
})
|
||||
if err != nil {
|
||||
m.cfg.logger.Error().Err(err).Msg("Failed to expand group membership")
|
||||
return nil, nil, err
|
||||
}
|
||||
if gr.GetStatus().GetCode() != rpc.Code_CODE_OK {
|
||||
m.cfg.logger.Error().Str("Status", gr.GetStatus().GetMessage()).Msg("Failed to expand group membership")
|
||||
return nil, nil, errtypes.NewErrtypeFromStatus(gr.GetStatus())
|
||||
}
|
||||
for _, u := range gr.GetMembers() {
|
||||
newReceivedShares = append(newReceivedShares, share.ReceivedShareWithUser{
|
||||
UserID: u,
|
||||
ReceivedShare: &collaboration.ReceivedShare{
|
||||
Share: newShare,
|
||||
State: collaboration.ShareState_SHARE_STATE_ACCEPTED,
|
||||
},
|
||||
})
|
||||
}
|
||||
// Also add a group-level entry (UserID == nil) so the group cache is populated.
|
||||
newReceivedShares = append(newReceivedShares, share.ReceivedShareWithUser{
|
||||
UserID: nil,
|
||||
ReceivedShare: &collaboration.ReceivedShare{
|
||||
Share: newShare,
|
||||
State: collaboration.ShareState_SHARE_STATE_ACCEPTED,
|
||||
},
|
||||
})
|
||||
case provider.GranteeType_GRANTEE_TYPE_USER:
|
||||
newReceivedShares = append(newReceivedShares, share.ReceivedShareWithUser{
|
||||
UserID: grant.GetGrantee().GetUserId(),
|
||||
ReceivedShare: &collaboration.ReceivedShare{
|
||||
Share: newShare,
|
||||
State: collaboration.ShareState_SHARE_STATE_ACCEPTED,
|
||||
},
|
||||
})
|
||||
}
|
||||
return newShare, newReceivedShares, nil
|
||||
}
|
||||
|
||||
// storageProviderForSpace resolves the storageprovider responsible for the
|
||||
// given storage space and returns a dialled client. In the default opencloud
|
||||
// deployment the storage registry is co-located with the gateway, so
|
||||
// the GatewayAddr is used as the registry address.
|
||||
func (m *ImportSpaceMembersMigration) storageProviderForSpace(ctx context.Context, space *provider.StorageSpace) (provider.ProviderAPIClient, error) {
|
||||
|
||||
srClient, err := pool.GetStorageRegistryClient(m.cfg.providerRegistryAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get storage registry client: %w", err)
|
||||
}
|
||||
|
||||
spaceJSON, err := json.Marshal(space)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal space: %w", err)
|
||||
}
|
||||
|
||||
res, err := srClient.GetStorageProviders(ctx, ®istry.GetStorageProvidersRequest{
|
||||
Opaque: &typesv1beta1.Opaque{
|
||||
Map: map[string]*typesv1beta1.OpaqueEntry{
|
||||
"space": {
|
||||
Decoder: "json",
|
||||
Value: spaceJSON,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetStorageProviders: %w", err)
|
||||
}
|
||||
if len(res.GetProviders()) == 0 {
|
||||
return nil, fmt.Errorf("no storage provider found for space %s", space.GetId().GetOpaqueId())
|
||||
}
|
||||
|
||||
c, err := pool.GetStorageProviderServiceClient(res.GetProviders()[0].GetAddress())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial storage provider: %w", err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
Generated
Vendored
+353
@@ -0,0 +1,353 @@
|
||||
// Copyright 2026 OpenCloud GmbH
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// In applying this license, CERN does not waive the privileges and immunities
|
||||
// granted to it by virtue of its status as an Intergovernmental Organization
|
||||
// or submit itself to any jurisdiction.
|
||||
|
||||
package migration
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slices"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
gatewayv1beta1 "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/errtypes"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/rgrpc/todo/pool"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/storage/utils/metadata"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const stateFile = "migrations/state.json"
|
||||
|
||||
const (
|
||||
lockFile = "migrations/lock.json"
|
||||
lockTTL = time.Minute
|
||||
lockHeartbeatInterval = 20 * time.Second
|
||||
)
|
||||
|
||||
// lockPollInterval is how long acquireLock sleeps between retries when the
|
||||
// lock is held by another instance. Declared as a variable so tests can
|
||||
// shorten it without rebuilding.
|
||||
var lockPollInterval = 5 * time.Second
|
||||
|
||||
// lockData is the content written to the lock file.
|
||||
type lockData struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
InstanceID string `json:"instance_id"`
|
||||
}
|
||||
|
||||
type migration interface {
|
||||
Name() string
|
||||
Version() int
|
||||
Initialize(config)
|
||||
Migrate() error
|
||||
}
|
||||
|
||||
// persistedState is the on-disk representation of the migration state.
|
||||
type persistedState struct {
|
||||
Version int `json:"version"`
|
||||
}
|
||||
|
||||
type state struct {
|
||||
version int
|
||||
}
|
||||
|
||||
// MigrationConfig holds all caller-supplied options for a migration run.
|
||||
// It is intentionally a plain struct so that new fields can be added without
|
||||
// changing function signatures throughout the call chain.
|
||||
type MigrationConfig struct {
|
||||
ServiceAccountID string
|
||||
ServiceAccountSecret string
|
||||
ProviderRegistryAddr string
|
||||
}
|
||||
|
||||
type config struct {
|
||||
logger zerolog.Logger
|
||||
gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient]
|
||||
storage metadata.Storage
|
||||
serviceAccountID string
|
||||
serviceAccountSecret string
|
||||
providerRegistryAddr string
|
||||
manager share.Manager
|
||||
loader share.LoadableManager
|
||||
}
|
||||
|
||||
type Migrations struct {
|
||||
config
|
||||
state state
|
||||
instanceID string
|
||||
}
|
||||
|
||||
var migrations []migration
|
||||
|
||||
// registerMigration is only supposed to be call from init(), which runs sequentially
|
||||
// so we don't need ot protect migrations with a lock
|
||||
func registerMigration(m migration) {
|
||||
migrations = append(migrations, m)
|
||||
}
|
||||
|
||||
func New(logger zerolog.Logger,
|
||||
gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient],
|
||||
storage metadata.Storage,
|
||||
cfg MigrationConfig,
|
||||
manager share.Manager,
|
||||
loader share.LoadableManager,
|
||||
) Migrations {
|
||||
|
||||
slices.SortFunc(migrations, func(a, b migration) int {
|
||||
return cmp.Compare(a.Version(), b.Version())
|
||||
})
|
||||
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
instanceID := fmt.Sprintf("%x", b)
|
||||
|
||||
return Migrations{
|
||||
config{
|
||||
logger: logger.With().Str("jsoncs3", "migrations").Logger(),
|
||||
gatewaySelector: gatewaySelector,
|
||||
storage: storage,
|
||||
serviceAccountID: cfg.ServiceAccountID,
|
||||
serviceAccountSecret: cfg.ServiceAccountSecret,
|
||||
providerRegistryAddr: cfg.ProviderRegistryAddr,
|
||||
manager: manager,
|
||||
loader: loader,
|
||||
},
|
||||
state{},
|
||||
instanceID,
|
||||
}
|
||||
}
|
||||
|
||||
// acquireLock tries to atomically create the lock file, blocking until the lock
|
||||
// is obtained. It returns the etag of the lock file on success. It retries
|
||||
// indefinitely until ctx is cancelled. A lock whose timestamp is older than
|
||||
// lockTTL is considered stale and will be taken over.
|
||||
func (m *Migrations) acquireLock(ctx context.Context) (string, error) {
|
||||
m.logger.Debug().Str("instance", m.instanceID).Msg("acquiring migration lock")
|
||||
for {
|
||||
// Fast path: create the lock file only if it does not exist yet.
|
||||
data, err := json.Marshal(lockData{Timestamp: time.Now(), InstanceID: m.instanceID})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
res, err := m.storage.Upload(ctx, metadata.UploadRequest{
|
||||
Path: lockFile,
|
||||
Content: data,
|
||||
IfNoneMatch: []string{"*"},
|
||||
})
|
||||
if err == nil {
|
||||
m.logger.Debug().Str("instance", m.instanceID).Msg("migration lock acquired")
|
||||
return res.Etag, nil
|
||||
}
|
||||
|
||||
// Propagate context cancellation immediately.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Any error other than a conflict means something unexpected happened.
|
||||
if !isConflict(err) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Lock file already exists — read it to decide whether it is stale.
|
||||
dl, err := m.storage.Download(ctx, metadata.DownloadRequest{Path: lockFile})
|
||||
if err != nil {
|
||||
if _, ok := err.(errtypes.IsNotFound); ok {
|
||||
// Lock was released between our upload attempt and the download;
|
||||
// retry acquiring it immediately.
|
||||
m.logger.Debug().Str("instance", m.instanceID).Msg("migration lock vanished during read; retrying")
|
||||
continue
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
var existing lockData
|
||||
stale := true
|
||||
if err := json.Unmarshal(dl.Content, &existing); err == nil {
|
||||
stale = time.Since(existing.Timestamp) > lockTTL
|
||||
}
|
||||
|
||||
if stale {
|
||||
m.logger.Debug().
|
||||
Str("instance", m.instanceID).
|
||||
Str("held_by", existing.InstanceID).
|
||||
Time("lock_timestamp", existing.Timestamp).
|
||||
Msg("migration lock is stale; attempting takeover")
|
||||
|
||||
// Atomically take over the stale lock using the etag we just read.
|
||||
newData, err := json.Marshal(lockData{Timestamp: time.Now(), InstanceID: m.instanceID})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
res, err := m.storage.Upload(ctx, metadata.UploadRequest{
|
||||
Path: lockFile,
|
||||
Content: newData,
|
||||
IfMatchEtag: dl.Etag,
|
||||
})
|
||||
if err == nil {
|
||||
m.logger.Debug().Str("instance", m.instanceID).Msg("migration lock acquired via stale takeover")
|
||||
return res.Etag, nil
|
||||
}
|
||||
// Another instance took the stale lock before us; loop and retry.
|
||||
m.logger.Debug().Str("instance", m.instanceID).Err(err).Msg("stale lock takeover lost race; retrying")
|
||||
continue
|
||||
}
|
||||
|
||||
m.logger.Debug().
|
||||
Str("instance", m.instanceID).
|
||||
Str("held_by", existing.InstanceID).
|
||||
Time("lock_timestamp", existing.Timestamp).
|
||||
Dur("poll_interval", lockPollInterval).
|
||||
Msg("migration lock held by another instance; waiting")
|
||||
|
||||
// Lock is fresh and held by another instance; wait before retrying.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(lockPollInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startHeartbeat spawns a goroutine that periodically renews the lock file so
|
||||
// that it is not considered stale while a long migration is running. Call the
|
||||
// returned cancel function to stop the heartbeat.
|
||||
func (m *Migrations) startHeartbeat(ctx context.Context, etag string) context.CancelFunc {
|
||||
hbCtx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
ticker := time.NewTicker(lockHeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-hbCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
data, err := json.Marshal(lockData{Timestamp: time.Now(), InstanceID: m.instanceID})
|
||||
if err != nil {
|
||||
m.logger.Warn().Err(err).Msg("failed to marshal heartbeat data for migration lock")
|
||||
return
|
||||
}
|
||||
res, err := m.storage.Upload(hbCtx, metadata.UploadRequest{
|
||||
Path: lockFile,
|
||||
Content: data,
|
||||
IfMatchEtag: etag,
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Warn().Err(err).Msg("failed to renew migration lock; another instance may take over")
|
||||
return
|
||||
}
|
||||
etag = res.Etag
|
||||
}
|
||||
}
|
||||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
// releaseLock deletes the lock file unconditionally.
|
||||
func (m *Migrations) releaseLock(ctx context.Context) {
|
||||
if err := m.storage.Delete(ctx, lockFile); err != nil {
|
||||
m.logger.Warn().Err(err).Msg("failed to release migration lock")
|
||||
}
|
||||
}
|
||||
|
||||
// isConflict returns true for errors that signal a conditional-upload conflict,
|
||||
// i.e. the lock file already exists or the etag did not match.
|
||||
func isConflict(err error) bool {
|
||||
switch err.(type) {
|
||||
case errtypes.IsAlreadyExists, errtypes.IsAborted, errtypes.IsPreconditionFailed:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadState reads the persisted migration version from storage. If no state
|
||||
// file exists yet (fresh deployment) it returns version 0 without error.
|
||||
func (m *Migrations) loadState(ctx context.Context) error {
|
||||
data, err := m.storage.SimpleDownload(ctx, stateFile)
|
||||
if err != nil {
|
||||
if _, ok := err.(errtypes.IsNotFound); ok {
|
||||
m.state = state{version: 0}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
var ps persistedState
|
||||
if err := json.Unmarshal(data, &ps); err != nil {
|
||||
return err
|
||||
}
|
||||
m.state = state{version: ps.Version}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveState writes the current migration version to storage so that already-
|
||||
// applied migrations are not re-run on the next server start.
|
||||
func (m *Migrations) saveState(ctx context.Context) error {
|
||||
data, err := json.Marshal(persistedState{Version: m.state.version})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.storage.SimpleUpload(ctx, stateFile, data)
|
||||
}
|
||||
|
||||
func (m *Migrations) RunMigrations() {
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
etag, err := m.acquireLock(ctx)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("failed to acquire migration lock; skipping migrations")
|
||||
return
|
||||
}
|
||||
cancelHB := m.startHeartbeat(ctx, etag)
|
||||
defer cancelHB()
|
||||
defer m.releaseLock(ctx)
|
||||
|
||||
if err := m.loadState(ctx); err != nil {
|
||||
m.logger.Error().Err(err).Msg("failed to load migration state; skipping migrations")
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Info().Int("current state", m.state.version).Msg("checking migrations")
|
||||
|
||||
for _, mig := range migrations {
|
||||
if mig.Version() > m.state.version {
|
||||
m.logger.Info().Str("migration", mig.Name()).Int("version", mig.Version()).Msg("running migration")
|
||||
mig.Initialize(m.config)
|
||||
if err := mig.Migrate(); err != nil {
|
||||
m.logger.Error().Err(err).Str("migration", mig.Name()).Msg("migration failed; stopping")
|
||||
return
|
||||
}
|
||||
m.state.version = mig.Version()
|
||||
if err := m.saveState(ctx); err != nil {
|
||||
m.logger.Error().Err(err).Msg("failed to save migration state; stopping")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
m.logger.Info().Str("migration", mig.Name()).Int("version", mig.Version()).Msg("skipping migration")
|
||||
}
|
||||
}
|
||||
}
|
||||
+2
-1
@@ -28,6 +28,7 @@ import (
|
||||
|
||||
ctxpkg "github.com/opencloud-eu/reva/v2/pkg/ctx"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/genproto/protobuf/field_mask"
|
||||
|
||||
userv1beta1 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
|
||||
@@ -46,7 +47,7 @@ func init() {
|
||||
}
|
||||
|
||||
// New returns a new manager.
|
||||
func New(c map[string]interface{}) (share.Manager, error) {
|
||||
func New(c map[string]any, _ *zerolog.Logger) (share.Manager, error) {
|
||||
state := map[string]map[*collaboration.ShareId]collaboration.ShareState{}
|
||||
mp := map[string]map[*collaboration.ShareId]*provider.Reference{}
|
||||
return &manager{
|
||||
|
||||
+5
-2
@@ -18,11 +18,14 @@
|
||||
|
||||
package registry
|
||||
|
||||
import "github.com/opencloud-eu/reva/v2/pkg/share"
|
||||
import (
|
||||
"github.com/opencloud-eu/reva/v2/pkg/share"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// NewFunc is the function that share managers
|
||||
// should register at init time.
|
||||
type NewFunc func(map[string]interface{}) (share.Manager, error)
|
||||
type NewFunc func(map[string]any, *zerolog.Logger) (share.Manager, error)
|
||||
|
||||
// NewFuncs is a map containing all the registered share managers.
|
||||
var NewFuncs = map[string]NewFunc{}
|
||||
|
||||
+1
-1
@@ -72,7 +72,7 @@ func NewNatsKeyValueFromJetStream(c Config, js jetstream.JetStream) (jetstream.K
|
||||
if err != nil {
|
||||
kvConfig := jetstream.KeyValueConfig{
|
||||
Bucket: c.Database,
|
||||
TTL: 0, // we don't do TTLs for this store
|
||||
TTL: c.TTL,
|
||||
}
|
||||
if c.DisablePersistence {
|
||||
kvConfig.Storage = jetstream.MemoryStorage
|
||||
|
||||
+5
-6
@@ -65,16 +65,16 @@ func (c *IDCache) DeleteByPath(ctx context.Context, path string) error {
|
||||
} else {
|
||||
err := c.kv.Purge(ctx, baseKey)
|
||||
if err != nil && err != nats.ErrKeyNotFound {
|
||||
appctx.GetLogger(ctx).Error().Err(err).Str("record", path).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not get spaceID and nodeID from cache")
|
||||
appctx.GetLogger(ctx).Error().Err(err).Str("record", baseKey).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not purge from cache")
|
||||
}
|
||||
|
||||
err = c.kv.Purge(ctx, cacheKey(spaceID, nodeID))
|
||||
if err != nil && err != nats.ErrKeyNotFound {
|
||||
appctx.GetLogger(ctx).Error().Err(err).Str("record", path).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not get spaceID and nodeID from cache")
|
||||
appctx.GetLogger(ctx).Error().Err(err).Str("record", cacheKey(spaceID, nodeID)).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not purge from cache")
|
||||
}
|
||||
}
|
||||
|
||||
watcher, err := c.kv.Watch(ctx, baseKey+".*")
|
||||
watcher, err := c.kv.Watch(ctx, baseKey+".>")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -85,7 +85,6 @@ func (c *IDCache) DeleteByPath(ctx context.Context, path string) error {
|
||||
break
|
||||
}
|
||||
key := update.Key()
|
||||
|
||||
spaceID, nodeID, ok := c.getByReverseCacheKey(ctx, key)
|
||||
if !ok {
|
||||
appctx.GetLogger(ctx).Error().Str("record", key).Msg("could not get spaceID and nodeID from cache")
|
||||
@@ -94,12 +93,12 @@ func (c *IDCache) DeleteByPath(ctx context.Context, path string) error {
|
||||
|
||||
err := c.kv.Purge(ctx, key)
|
||||
if err != nil && err != nats.ErrKeyNotFound {
|
||||
appctx.GetLogger(ctx).Error().Err(err).Str("record", key).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not get spaceID and nodeID from cache")
|
||||
appctx.GetLogger(ctx).Error().Err(err).Str("record", key).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not purge from cache")
|
||||
}
|
||||
|
||||
err = c.kv.Purge(ctx, cacheKey(spaceID, nodeID))
|
||||
if err != nil && err != nats.ErrKeyNotFound {
|
||||
appctx.GetLogger(ctx).Error().Err(err).Str("record", key).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not get spaceID and nodeID from cache")
|
||||
appctx.GetLogger(ctx).Error().Err(err).Str("record", cacheKey(spaceID, nodeID)).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not purge from cache")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
+11
-1
@@ -24,6 +24,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
tusd "github.com/tus/tusd/v2/pkg/handler"
|
||||
@@ -58,7 +59,8 @@ func init() {
|
||||
type posixFS struct {
|
||||
storage.FS
|
||||
|
||||
um usermapper.Mapper
|
||||
tree *tree.Tree
|
||||
um usermapper.Mapper
|
||||
}
|
||||
|
||||
// New returns an implementation to of the storage.FS interface that talk to
|
||||
@@ -70,6 +72,7 @@ func NewDefault(m map[string]interface{}, stream events.Stream, log *zerolog.Log
|
||||
}
|
||||
|
||||
o.IDCache.Database += "_v2" // Use a versioned bucket name to avoid conflicts with previous implementations
|
||||
o.IDCache.TTL = 0 // Disable TTL for the ID cache, as the posix driver relies on it for caching file IDs and we don't want them to expire
|
||||
kv, err := cache.NewNatsKeyValue(o.IDCache)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not create nats key value store")
|
||||
@@ -80,6 +83,7 @@ func NewDefault(m map[string]interface{}, stream events.Stream, log *zerolog.Log
|
||||
}
|
||||
|
||||
o.IDCache.Database += "_history" // Use a versioned bucket name to avoid conflicts with previous implementations
|
||||
o.IDCache.TTL = 24 * 60 * time.Minute
|
||||
historyKv, err := cache.NewNatsKeyValue(o.IDCache)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not create nats key value store")
|
||||
@@ -215,11 +219,17 @@ func New(o *options.Options, stream events.Stream, cache, historyCache *idcache.
|
||||
|
||||
mw := middleware.NewFS(dfs, hooks...)
|
||||
fs.FS = mw
|
||||
fs.tree = tp
|
||||
fs.um = um
|
||||
|
||||
return fs, nil
|
||||
}
|
||||
|
||||
// WarmupIDCache allows triggering a posix fs scan and id cache warmup manually.
|
||||
func (fs *posixFS) WarmupIDCache(root string, assimilate, onlyDirty bool) error {
|
||||
return fs.tree.WarmupIDCache(root, assimilate, onlyDirty)
|
||||
}
|
||||
|
||||
// ListUploadSessions returns the upload sessions matching the given filter
|
||||
func (fs *posixFS) ListUploadSessions(ctx context.Context, filter storage.UploadSessionFilter) ([]storage.UploadSession, error) {
|
||||
return fs.FS.(storage.UploadSessionLister).ListUploadSessions(ctx, filter)
|
||||
|
||||
-15
@@ -37,10 +37,8 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
user "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
|
||||
provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1"
|
||||
|
||||
"github.com/opencloud-eu/reva/v2/pkg/appctx"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/errtypes"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/events"
|
||||
"github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/blobstore"
|
||||
@@ -590,11 +588,6 @@ func (t *Tree) Delete(ctx context.Context, n *node.Node) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if appctx.DeletingSharedResourceFromContext(ctx) {
|
||||
src := filepath.Join(n.ParentPath(), n.Name)
|
||||
return os.RemoveAll(src)
|
||||
}
|
||||
|
||||
var sizeDiff int64
|
||||
if n.IsDir(ctx) {
|
||||
treesize, err := n.GetTreeSize(ctx)
|
||||
@@ -819,11 +812,3 @@ func isLockFile(path string) bool {
|
||||
func isTrash(path string) bool {
|
||||
return strings.HasSuffix(path, ".trashinfo") || strings.HasSuffix(path, ".trashitem") || strings.Contains(path, ".Trash")
|
||||
}
|
||||
|
||||
func (t *Tree) AddLabel(ctx context.Context, ref *provider.Reference, userID *user.UserId, label string) error {
|
||||
return errtypes.NotSupported("AddLabel not implemented")
|
||||
}
|
||||
|
||||
func (t *Tree) RemoveLabel(ctx context.Context, ref *provider.Reference, userID *user.UserId, label string) error {
|
||||
return errtypes.NotSupported("RemoveLabel not implemented")
|
||||
}
|
||||
|
||||
Generated
Vendored
+1
@@ -101,6 +101,7 @@ func ServiceAccountPermissions() *provider.ResourcePermissions {
|
||||
Delete: true, // for cli restore command with replace option
|
||||
CreateContainer: true, // for space provisioning
|
||||
AddGrant: true, // for initial project space member assignment
|
||||
ListGrants: true, // for initial project space member assignment
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
-5
@@ -429,11 +429,6 @@ func (t *Tree) Delete(ctx context.Context, n *node.Node) (err error) {
|
||||
// remove entry from cache immediately to avoid inconsistencies
|
||||
defer func() { _ = t.idCache.Delete(path) }()
|
||||
|
||||
if appctx.DeletingSharedResourceFromContext(ctx) {
|
||||
src := filepath.Join(n.ParentPath(), n.Name)
|
||||
return os.Remove(src)
|
||||
}
|
||||
|
||||
// get the original path
|
||||
origin, err := t.lookup.Path(ctx, n, node.NoCheck)
|
||||
if err != nil {
|
||||
|
||||
Generated
Vendored
-5
@@ -445,11 +445,6 @@ func (t *Tree) Delete(ctx context.Context, n *node.Node) (err error) {
|
||||
// remove entry from cache immediately to avoid inconsistencies
|
||||
defer func() { _ = t.idCache.Delete(path) }()
|
||||
|
||||
if appctx.DeletingSharedResourceFromContext(ctx) {
|
||||
src := filepath.Join(n.ParentPath(), n.Name)
|
||||
return os.Remove(src)
|
||||
}
|
||||
|
||||
// get the original path
|
||||
origin, err := t.lookup.Path(ctx, n, node.NoCheck)
|
||||
if err != nil {
|
||||
|
||||
+38
-1
@@ -93,6 +93,40 @@ func (disk *Disk) SimpleUpload(ctx context.Context, uploadpath string, content [
|
||||
// Upload stores a file on disk
|
||||
func (disk *Disk) Upload(_ context.Context, req UploadRequest) (*UploadResponse, error) {
|
||||
p := disk.targetPath(req.Path)
|
||||
|
||||
// IfNoneMatch: ["*"] means create the file only if it does not already
|
||||
// exist. Use O_EXCL so the check and the create are atomic on the local
|
||||
// filesystem.
|
||||
for _, tag := range req.IfNoneMatch {
|
||||
if tag != "*" {
|
||||
continue
|
||||
}
|
||||
f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrExist) {
|
||||
return nil, errtypes.AlreadyExists(p)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if _, err := f.Write(req.Content); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res := &UploadResponse{}
|
||||
res.Etag, err = calcEtag(info.ModTime(), info.Size())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
if req.IfMatchEtag != "" {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
@@ -170,7 +204,10 @@ func (disk *Disk) Download(_ context.Context, req DownloadRequest) (*DownloadRes
|
||||
// SimpleDownload reads a file from disk
|
||||
func (disk *Disk) SimpleDownload(ctx context.Context, downloadpath string) ([]byte, error) {
|
||||
res, err := disk.Download(ctx, DownloadRequest{Path: downloadpath})
|
||||
return res.Content, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.Content, nil
|
||||
}
|
||||
|
||||
// Delete deletes a path
|
||||
|
||||
+48
-41
@@ -266,15 +266,10 @@ func (i *Identity) GetLDAPUserByFilter(ctx context.Context, lc ldap.Client, filt
|
||||
res, err := lc.Search(searchRequest)
|
||||
if err != nil {
|
||||
log.Debug().Str("backend", "ldap").Err(err).Str("userfilter", filter).Msg("Error looking up user by filter")
|
||||
var errmsg string
|
||||
if lerr, ok := err.(*ldap.Error); ok {
|
||||
if lerr.ResultCode == ldap.LDAPResultSizeLimitExceeded {
|
||||
errmsg = fmt.Sprintf("too many results searching for user '%s'", filter)
|
||||
}
|
||||
}
|
||||
span.SetAttributes(attribute.String("ldap.error", errmsg))
|
||||
span.SetStatus(codes.Error, errmsg)
|
||||
return nil, errtypes.NotFound(errmsg)
|
||||
classified := classifySearchError(err, fmt.Sprintf("too many results searching for user '%s'", filter))
|
||||
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
|
||||
span.SetStatus(codes.Error, classified.Error())
|
||||
return nil, classified
|
||||
}
|
||||
if len(res.Entries) == 0 {
|
||||
return nil, errtypes.NotFound(filter)
|
||||
@@ -306,9 +301,10 @@ func (i *Identity) GetLDAPUserByDN(ctx context.Context, lc ldap.Client, dn strin
|
||||
res, err := lc.Search(searchRequest)
|
||||
if err != nil {
|
||||
log.Debug().Str("backend", "ldap").Err(err).Str("dn", dn).Msg("Error looking up user by DN")
|
||||
span.SetAttributes(attribute.String("ldap.error", err.Error()))
|
||||
span.SetStatus(codes.Error, "")
|
||||
return nil, errtypes.NotFound(dn)
|
||||
classified := classifySearchError(err, "")
|
||||
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
|
||||
span.SetStatus(codes.Error, classified.Error())
|
||||
return nil, classified
|
||||
}
|
||||
span.SetStatus(codes.Ok, "")
|
||||
if len(res.Entries) == 0 {
|
||||
@@ -337,9 +333,10 @@ func (i *Identity) GetLDAPUsers(ctx context.Context, lc ldap.Client, query, tena
|
||||
sr, err := lc.Search(searchRequest)
|
||||
if err != nil {
|
||||
log.Debug().Str("backend", "ldap").Err(err).Str("filter", filter).Msg("Error searching users")
|
||||
span.SetAttributes(attribute.String("ldap.error", err.Error()))
|
||||
span.SetStatus(codes.Error, "")
|
||||
return nil, errtypes.NotFound(query)
|
||||
classified := classifySearchError(err, "")
|
||||
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
|
||||
span.SetStatus(codes.Error, classified.Error())
|
||||
return nil, classified
|
||||
}
|
||||
|
||||
span.SetAttributes(attribute.Int("ldap.result_count", len(sr.Entries)))
|
||||
@@ -376,7 +373,8 @@ func (i *Identity) IsLDAPUserInDisabledGroup(ctx context.Context, lc ldap.Client
|
||||
sr, err := lc.Search(searchRequest)
|
||||
if err != nil {
|
||||
log.Error().Str("backend", "ldap").Err(err).Str("filter", filter).Msg("Error looking up error group")
|
||||
// Err on the side of caution.
|
||||
// Err on the side of caution: treat search failures (including network
|
||||
// errors) as if the user is in the disabled group.
|
||||
span.SetAttributes(attribute.String("ldap.error", err.Error()))
|
||||
span.SetStatus(codes.Error, "")
|
||||
return true
|
||||
@@ -423,10 +421,10 @@ func (i *Identity) GetLDAPUserGroups(ctx context.Context, lc ldap.Client, userEn
|
||||
// not having any groups in LDAP
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
span.SetAttributes(attribute.String("ldap.error", err.Error()))
|
||||
span.SetStatus(codes.Error, "")
|
||||
return []string{}, err
|
||||
classified := classifySearchError(err, "")
|
||||
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
|
||||
span.SetStatus(codes.Error, classified.Error())
|
||||
return nil, classified
|
||||
}
|
||||
span.SetStatus(codes.Ok, "")
|
||||
span.SetAttributes(attribute.Int("ldap.result_count", len(sr.Entries)))
|
||||
@@ -504,15 +502,10 @@ func (i *Identity) GetLDAPGroupByFilter(ctx context.Context, lc ldap.Client, fil
|
||||
res, err := lc.Search(searchRequest)
|
||||
if err != nil {
|
||||
log.Debug().Str("backend", "ldap").Err(err).Str("filter", filter).Msg("Error looking up group by filter")
|
||||
var errmsg string
|
||||
if lerr, ok := err.(*ldap.Error); ok {
|
||||
if lerr.ResultCode == ldap.LDAPResultSizeLimitExceeded {
|
||||
errmsg = fmt.Sprintf("too many results searching for group '%s'", filter)
|
||||
}
|
||||
}
|
||||
span.SetAttributes(attribute.String("ldap.error", errmsg))
|
||||
span.SetStatus(codes.Error, "")
|
||||
return nil, errtypes.NotFound(errmsg)
|
||||
classified := classifySearchError(err, fmt.Sprintf("too many results searching for group '%s'", filter))
|
||||
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
|
||||
span.SetStatus(codes.Error, classified.Error())
|
||||
return nil, classified
|
||||
}
|
||||
if len(res.Entries) == 0 {
|
||||
return nil, errtypes.NotFound(filter)
|
||||
@@ -543,10 +536,11 @@ func (i *Identity) GetLDAPGroups(ctx context.Context, lc ldap.Client, query stri
|
||||
setLDAPSearchSpanAttributes(span, searchRequest)
|
||||
sr, err := lc.Search(searchRequest)
|
||||
if err != nil {
|
||||
span.SetAttributes(attribute.String("ldap.error", err.Error()))
|
||||
span.SetStatus(codes.Error, "")
|
||||
log.Debug().Str("backend", "ldap").Err(err).Str("query", query).Msg("Error search for groups")
|
||||
return nil, errtypes.NotFound(query)
|
||||
classified := classifySearchError(err, "")
|
||||
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
|
||||
span.SetStatus(codes.Error, classified.Error())
|
||||
return nil, classified
|
||||
}
|
||||
span.SetStatus(codes.Ok, "")
|
||||
return sr.Entries, nil
|
||||
@@ -919,15 +913,10 @@ func (i *Identity) GetLDAPTenantByFilter(ctx context.Context, lc ldap.Client, fi
|
||||
res, err := lc.Search(searchRequest)
|
||||
if err != nil {
|
||||
log.Debug().Str("backend", "ldap").Err(err).Str("tenantfilter", filter).Msg("Error looking up tenant by filter")
|
||||
var errmsg string
|
||||
if lerr, ok := err.(*ldap.Error); ok {
|
||||
if lerr.ResultCode == ldap.LDAPResultSizeLimitExceeded {
|
||||
errmsg = fmt.Sprintf("too many results searching for tenant '%s'", filter)
|
||||
}
|
||||
}
|
||||
span.SetAttributes(attribute.String("ldap.error", errmsg))
|
||||
span.SetStatus(codes.Error, errmsg)
|
||||
return nil, errtypes.NotFound(errmsg)
|
||||
classified := classifySearchError(err, fmt.Sprintf("too many results searching for tenant '%s'", filter))
|
||||
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
|
||||
span.SetStatus(codes.Error, classified.Error())
|
||||
return nil, classified
|
||||
}
|
||||
if len(res.Entries) == 0 {
|
||||
return nil, errtypes.NotFound(filter)
|
||||
@@ -980,6 +969,24 @@ func (i *Identity) getTenantAttributeFilter(attribute, value string) (string, er
|
||||
), nil
|
||||
}
|
||||
|
||||
// classifySearchError maps a raw error from lc.Search to the appropriate
|
||||
// errtypes value:
|
||||
// - ldap.ErrorNetwork → errtypes.Unavailable (transient; caller should retry)
|
||||
// - ldap.LDAPResultSizeLimitExceeded → errtypes.NotFound(sizeExceededMsg)
|
||||
// - anything else → errtypes.NotFound("") (preserving prior behaviour)
|
||||
//
|
||||
// The sizeExceededMsg is only used for the SizeLimitExceeded case; pass an
|
||||
// empty string if the caller does not need a custom message for that case.
|
||||
func classifySearchError(err error, sizeExceededMsg string) error {
|
||||
if ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
|
||||
return errtypes.Unavailable("ldap server unreachable: " + err.Error())
|
||||
}
|
||||
if sizeExceededMsg != "" && ldap.IsErrorWithCode(err, ldap.LDAPResultSizeLimitExceeded) {
|
||||
return errtypes.NotFound(sizeExceededMsg)
|
||||
}
|
||||
return errtypes.NotFound("")
|
||||
}
|
||||
|
||||
func setLDAPSearchSpanAttributes(span trace.Span, request *ldap.SearchRequest) {
|
||||
span.SetAttributes(
|
||||
attribute.String("ldap.basedn", request.BaseDN),
|
||||
|
||||
+1
-1
@@ -586,7 +586,7 @@ func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader
|
||||
|
||||
// Length of encrypted portion of the packet (header, payload, padding).
|
||||
// Enforce minimum padding and packet size.
|
||||
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize)
|
||||
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPacketSize)
|
||||
// Enforce block size.
|
||||
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize
|
||||
|
||||
|
||||
+7
-3
@@ -274,10 +274,14 @@ func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiA
|
||||
}
|
||||
|
||||
// Filter algorithms based on those supported by MultiAlgorithmSigner.
|
||||
// Iterate over the signer's algorithms first to preserve its preference order.
|
||||
supportedKeyAlgos := algorithmsForKeyFormat(keyFormat)
|
||||
var keyAlgos []string
|
||||
for _, algo := range algorithmsForKeyFormat(keyFormat) {
|
||||
if slices.Contains(as.Algorithms(), underlyingAlgo(algo)) {
|
||||
keyAlgos = append(keyAlgos, algo)
|
||||
for _, signerAlgo := range as.Algorithms() {
|
||||
if idx := slices.IndexFunc(supportedKeyAlgos, func(algo string) bool {
|
||||
return underlyingAlgo(algo) == signerAlgo
|
||||
}); idx >= 0 {
|
||||
keyAlgos = append(keyAlgos, supportedKeyAlgos[idx])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+32
-1
@@ -61,13 +61,42 @@ func (r *responseDeduper) addAll(dr *DriverResponse) {
|
||||
}
|
||||
|
||||
func (r *responseDeduper) addPackage(p *Package) {
|
||||
if r.seenPackages[p.ID] != nil {
|
||||
if prev := r.seenPackages[p.ID]; prev != nil {
|
||||
// Package already seen in a previous response. Merge the file lists,
|
||||
// removing duplicates. This can happen when the same package appears
|
||||
// in multiple driver responses that are being merged together.
|
||||
prev.GoFiles = appendUniqueStrings(prev.GoFiles, p.GoFiles)
|
||||
prev.CompiledGoFiles = appendUniqueStrings(prev.CompiledGoFiles, p.CompiledGoFiles)
|
||||
prev.OtherFiles = appendUniqueStrings(prev.OtherFiles, p.OtherFiles)
|
||||
prev.IgnoredFiles = appendUniqueStrings(prev.IgnoredFiles, p.IgnoredFiles)
|
||||
prev.EmbedFiles = appendUniqueStrings(prev.EmbedFiles, p.EmbedFiles)
|
||||
prev.EmbedPatterns = appendUniqueStrings(prev.EmbedPatterns, p.EmbedPatterns)
|
||||
return
|
||||
}
|
||||
r.seenPackages[p.ID] = p
|
||||
r.dr.Packages = append(r.dr.Packages, p)
|
||||
}
|
||||
|
||||
// appendUniqueStrings appends elements from src to dst, skipping duplicates.
|
||||
func appendUniqueStrings(dst, src []string) []string {
|
||||
if len(src) == 0 {
|
||||
return dst
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(dst))
|
||||
for _, s := range dst {
|
||||
seen[s] = true
|
||||
}
|
||||
|
||||
for _, s := range src {
|
||||
if !seen[s] {
|
||||
dst = append(dst, s)
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (r *responseDeduper) addRoot(id string) {
|
||||
if r.seenRoots[id] {
|
||||
return
|
||||
@@ -832,6 +861,8 @@ func golistargs(cfg *Config, words []string, goVersion int) []string {
|
||||
// go list doesn't let you pass -test and -find together,
|
||||
// probably because you'd just get the TestMain.
|
||||
fmt.Sprintf("-find=%t", !cfg.Tests && cfg.Mode&findFlags == 0 && !usesExportData(cfg)),
|
||||
// VCS information is not needed when not printing Stale or StaleReason fields
|
||||
"-buildvcs=false",
|
||||
}
|
||||
|
||||
// golang/go#60456: with go1.21 and later, go list serves pgo variants, which
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user