diff --git a/go.mod b/go.mod index 70c331e522..1d273828af 100644 --- a/go.mod +++ b/go.mod @@ -37,13 +37,13 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 - github.com/stretchr/testify v1.8.2 + github.com/stretchr/testify v1.8.4 github.com/tetratelabs/wazero v1.3.1 github.com/things-go/go-socks5 v0.0.3 github.com/xlab/treeprint v1.2.0 github.com/yiya1989/sshkrb5 v0.0.0-20201110125252-a1455b75a35e golang.org/x/crypto v0.14.0 - golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 + golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/net v0.17.0 golang.org/x/sys v0.13.0 golang.org/x/term v0.13.0 @@ -59,7 +59,7 @@ require ( gorm.io/gorm v1.25.2 gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f modernc.org/sqlite v1.23.1 - tailscale.com v1.44.0 + tailscale.com v1.50.1 ) require ( @@ -88,7 +88,9 @@ require ( github.com/aws/smithy-go v1.13.5 // indirect github.com/chromedp/sysutil v1.0.0 // indirect github.com/coreos/go-iptables v0.6.0 // indirect + github.com/coreos/go-systemd/v22 v22.4.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dblohm7/wingoes v0.0.0-20230821191801-fc76608aecf0 // indirect github.com/demisto/goxforce v0.0.0-20160322194047-db8357535b1d // indirect github.com/dlclark/regexp2 v1.4.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -101,11 +103,12 @@ require ( github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/gobwas/ws v1.1.0 // indirect - github.com/godbus/dbus/v5 v5.1.0 // indirect + github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.5.9 // indirect + github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect 
github.com/hdevalence/ed25519consensus v0.1.0 // indirect github.com/illarion/gonotify v1.0.1 // indirect @@ -141,10 +144,10 @@ require ( github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/tailscale/certstore v0.1.1-0.20220316223106-78d6e1c49d8d // indirect - github.com/tailscale/golang-x-crypto v0.0.0-20221115211329-17a3db2c30d2 // indirect + github.com/tailscale/golang-x-crypto v0.0.0-20230713185742-f0b76a10a08e // indirect github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 // indirect - github.com/tailscale/wireguard-go v0.0.0-20230410165232-af172621b4dd // indirect + github.com/tailscale/wireguard-go v0.0.0-20230824215414-93bd5cbf7fd8 // indirect github.com/tcnksm/go-httpstat v0.2.0 // indirect github.com/thedevsaddam/gojsonq/v2 v2.5.2 // indirect github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect @@ -152,8 +155,8 @@ require ( github.com/vishvananda/netns v0.0.4 // indirect github.com/x448/float16 v0.8.4 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect - go4.org/netipx v0.0.0-20230303233057-f1b76eb4bb35 // indirect - golang.org/x/mod v0.10.0 // indirect + go4.org/netipx v0.0.0-20230728180743-ad4cb58a6516 // indirect + golang.org/x/mod v0.11.0 // indirect golang.org/x/sync v0.2.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.9.1 // indirect diff --git a/go.sum b/go.sum index d53394e860..cd5758672c 100644 --- a/go.sum +++ b/go.sum @@ -91,13 +91,16 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= 
+github.com/coreos/go-systemd/v22 v22.4.0 h1:y9YHcjnjynCd/DVbg5j9L/33jQM3MxJlbj/zWskzfGU= +github.com/coreos/go-systemd/v22 v22.4.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dblohm7/wingoes v0.0.0-20230426155039-111c8c3b57c8 h1:vtIE3GO4hKplR58aTRx3yLPqAbfWyoyYrE8PXUv0Prw= +github.com/dblohm7/wingoes v0.0.0-20230821191801-fc76608aecf0 h1:/dgKwHVTI0J+A0zd/BHOF2CTn1deN0735cJrb+w2hbE= +github.com/dblohm7/wingoes v0.0.0-20230821191801-fc76608aecf0/go.mod h1:6NCrWM5jRefaG7iN0iMShPalLsljHWBh9v1zxM2f8Xs= github.com/demisto/goxforce v0.0.0-20160322194047-db8357535b1d h1:hmOGJg3cq5XK2aMs7R4kXXVSHqHMaC5hI5fwkX7V2zE= github.com/demisto/goxforce v0.0.0-20160322194047-db8357535b1d/go.mod h1:q72QzdO6OUjwTqnLCFJczIQ7GsBa4ffzkQiQcq6rVTY= github.com/dlclark/regexp2 v1.4.0 h1:F1rxgk7p4uKjwIQxBs9oAXe5CqrXlCduYEJvrF4u93E= @@ -149,8 +152,9 @@ github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gobwas/ws v1.1.0 h1:7RFti/xnNkMJnrK7D1yQ/iCIB5OrrY/54/H930kIbHA= github.com/gobwas/ws v1.1.0/go.mod h1:nzvNcVha5eUziGrbxFCo6qFIojQHjJV5cLYIbezhfL0= -github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= -github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/godbus/dbus/v5 
v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= +github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -177,6 +181,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c h1:06RMfw+TMMHtRuUOroMeatRCCgSMWXCJQeABvHU69YQ= +github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c/go.mod h1:BVIYo3cdnT4qSylnYqcd5YtmXhr51cJPGtnLBe/uLBU= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -344,7 +350,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify 
v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -353,19 +358,18 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.4/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tailscale/certstore v0.1.1-0.20220316223106-78d6e1c49d8d h1:K3j02b5j2Iw1xoggN9B2DIEkhWGheqFOeDkdJdBrJI8= github.com/tailscale/certstore v0.1.1-0.20220316223106-78d6e1c49d8d/go.mod h1:2P+hpOwd53e7JMX/L4f3VXkv1G+33ES6IWZSrkIeWNs= -github.com/tailscale/golang-x-crypto v0.0.0-20221115211329-17a3db2c30d2 h1:pBpqbsyX9H8c26oPYC2H+232HOdp1gDnCztoKmKWKDA= -github.com/tailscale/golang-x-crypto v0.0.0-20221115211329-17a3db2c30d2/go.mod h1:V2G8jyemEGZWKQ+3xNn4+bOx+FuoXU9Zc5GUsZMthBg= +github.com/tailscale/golang-x-crypto v0.0.0-20230713185742-f0b76a10a08e h1:JyeJF/HuSwvxWtsR1c0oKX1lzaSH5Wh4aX+MgiStaGQ= +github.com/tailscale/golang-x-crypto v0.0.0-20230713185742-f0b76a10a08e/go.mod h1:DjoeCULdP6vTJ/xY+nzzR9LaUHprkbZEpNidX0aqEEk= github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 h1:4chzWmimtJPxRs2O36yuGRW3f9SYV+bMTTvMBI0EKio= github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05/go.mod h1:PdCqy9JzfWMJf1H5UJW2ip33/d4YkoKN0r67yKH1mG8= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQSPhaUPjUQwozcRdDdSxxqhNgNZ3drZFk= github.com/tailscale/netlink 
v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= -github.com/tailscale/wireguard-go v0.0.0-20230410165232-af172621b4dd h1:+fBevMGmDRNi0oWD4SJXmPKLWvIBYX1NroMjo9czjcY= -github.com/tailscale/wireguard-go v0.0.0-20230410165232-af172621b4dd/go.mod h1:QRIcq2+DbdIC5sKh/gcAZhuqu6WT6L6G8/ALPN5wqYw= +github.com/tailscale/wireguard-go v0.0.0-20230824215414-93bd5cbf7fd8 h1:V9kSpiTzFp7OTgJinu/kSJlsI6EfRs8wJgQ+Q+5a8v4= +github.com/tailscale/wireguard-go v0.0.0-20230824215414-93bd5cbf7fd8/go.mod h1:QRIcq2+DbdIC5sKh/gcAZhuqu6WT6L6G8/ALPN5wqYw= github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0= github.com/tcnksm/go-httpstat v0.2.0/go.mod h1:s3JVJFtQxtBEBC9dwcdTTXS9xFnM3SXAZwPG41aurT8= github.com/tetratelabs/wazero v1.3.1 h1:rnb9FgOEQRLLR8tgoD1mfjNjMhFeWRUk+a4b4j/GpUM= @@ -401,8 +405,8 @@ go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9i go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= -go4.org/netipx v0.0.0-20230303233057-f1b76eb4bb35 h1:nJAwRlGWZZDOD+6wni9KVUNHMpHko/OnRwsrCYeAzPo= -go4.org/netipx v0.0.0-20230303233057-f1b76eb4bb35/go.mod h1:TQvodOM+hJTioNQJilmLXu08JNb8i+ccq418+KWu1/Y= +go4.org/netipx v0.0.0-20230728180743-ad4cb58a6516 h1:X66ZEoMN2SuaoI/dfZVYobB6E5zjZyyHUMWlCA7MgGE= +go4.org/netipx v0.0.0-20230728180743-ad4cb58a6516/go.mod h1:TQvodOM+hJTioNQJilmLXu08JNb8i+ccq418+KWu1/Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -414,8 +418,8 @@ golang.org/x/crypto 
v0.0.0-20220208050332-20e1d8d225ab/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/exp/typeparams v0.0.0-20230425010034-47ecfdc1ba53 h1:w/MOPdQ1IoYoDou3L55ZbTx2Nhn7JAhX1BBZor8qChU= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -424,8 +428,8 @@ golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= -golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net 
v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -486,6 +490,7 @@ golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -619,5 +624,5 @@ modernc.org/z v1.7.3 h1:zDJf6iHjrnB+WRD88stbXokugjyc0/pB91ri1gO6LZY= nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= software.sslmate.com/src/go-pkcs12 v0.2.0 h1:nlFkj7bTysH6VkC4fGphtjXRbezREPgrHuJG20hBGPE= -tailscale.com v1.44.0 h1:MPos9n30kJvdyfL52045gVFyNg93K+bwgDsr8gqKq2o= -tailscale.com v1.44.0/go.mod h1:+iYwTdeHyVJuNDu42Zafwihq1Uqfh+pW7pRaY1GD328= +tailscale.com v1.50.1 h1:q3lwxT2Y2ezc+FBCMHP8M14cgu1V0JiuLikojdsXuGU= +tailscale.com v1.50.1/go.mod h1:lBw7+Mw2d7rea3kefGjYWN8IJkB5dyaakMNMOinNGDo= diff --git a/vendor/github.com/coreos/go-systemd/v22/LICENSE b/vendor/github.com/coreos/go-systemd/v22/LICENSE new file mode 100644 index 0000000000..37ec93a14f --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/v22/LICENSE @@ -0,0 +1,191 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. 
+ +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. + +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made +available under the License, as indicated by a copyright notice that is included +in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. 
+ +"Contribution" shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative Works +thereof, that is intentionally submitted to Licensor for inclusion in the Work +by the copyright owner or by an individual or Legal Entity authorized to submit +on behalf of the copyright owner. For the purposes of this definition, +"submitted" means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor for +the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the Work and such +Derivative Works in Source or Object form. + +3. Grant of Patent License. 
+ +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable (except as stated in this section) patent license to make, have +made, use, offer to sell, sell, import, and otherwise transfer the Work, where +such license applies only to those patent claims licensable by such Contributor +that are necessarily infringed by their Contribution(s) alone or by combination +of their Contribution(s) with the Work to which such Contribution(s) was +submitted. If You institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work or a +Contribution incorporated within the Work constitutes direct or contributory +patent infringement, then any patent licenses granted to You under this License +for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. + +You may reproduce and distribute copies of the Work or Derivative Works thereof +in any medium, with or without modifications, and in Source or Object form, +provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of +this License; and +You must cause any modified files to carry prominent notices stating that You +changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source form +of the Work, excluding those notices that do not pertain to any part of the +Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any +Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE text file 
distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. + +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted +for inclusion in the Work by You to the Licensor shall be under the terms and +conditions of this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify the terms of +any separate license agreement you may have executed with Licensor regarding +such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, +service marks, or product names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
+ +Unless required by applicable law or agreed to in writing, Licensor provides the +Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, +including, without limitation, any warranties or conditions of TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are +solely responsible for determining the appropriateness of using or +redistributing the Work and assume any risks associated with Your exercise of +permissions under this License. + +8. Limitation of Liability. + +In no event and under no legal theory, whether in tort (including negligence), +contract, or otherwise, unless required by applicable law (such as deliberate +and grossly negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, incidental, +or consequential damages of any character arising as a result of this License or +out of the use or inability to use the Work (including but not limited to +damages for loss of goodwill, work stoppage, computer failure or malfunction, or +any and all other commercial damages or losses), even if such Contributor has +been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to +offer, and charge a fee for, acceptance of support, warranty, indemnity, or +other liability obligations and/or rights consistent with this License. However, +in accepting such obligations, You may act only on Your own behalf and on Your +sole responsibility, not on behalf of any other Contributor, and only if You +agree to indemnify, defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason of your +accepting any such warranty or additional liability. 
+ +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets "[]" replaced with your own +identifying information. (Don't include the brackets!) The text should be +enclosed in the appropriate comment syntax for the file format. We also +recommend that a file or class name and description of purpose be included on +the same "printed page" as the copyright notice for easier identification within +third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/coreos/go-systemd/v22/NOTICE b/vendor/github.com/coreos/go-systemd/v22/NOTICE new file mode 100644 index 0000000000..23a0ada2fb --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/v22/NOTICE @@ -0,0 +1,5 @@ +CoreOS Project +Copyright 2018 CoreOS, Inc + +This product includes software developed at CoreOS, Inc. +(http://www.coreos.com/). diff --git a/vendor/github.com/coreos/go-systemd/v22/dbus/dbus.go b/vendor/github.com/coreos/go-systemd/v22/dbus/dbus.go new file mode 100644 index 0000000000..147f756fe2 --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/v22/dbus/dbus.go @@ -0,0 +1,266 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Integration with the systemd D-Bus API. See http://www.freedesktop.org/wiki/Software/systemd/dbus/ +package dbus + +import ( + "context" + "encoding/hex" + "fmt" + "os" + "strconv" + "strings" + "sync" + + "github.com/godbus/dbus/v5" +) + +const ( + alpha = `abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ` + num = `0123456789` + alphanum = alpha + num + signalBuffer = 100 +) + +// needsEscape checks whether a byte in a potential dbus ObjectPath needs to be escaped +func needsEscape(i int, b byte) bool { + // Escape everything that is not a-z-A-Z-0-9 + // Also escape 0-9 if it's the first character + return strings.IndexByte(alphanum, b) == -1 || + (i == 0 && strings.IndexByte(num, b) != -1) +} + +// PathBusEscape sanitizes a constituent string of a dbus ObjectPath using the +// rules that systemd uses for serializing special characters. +func PathBusEscape(path string) string { + // Special case the empty string + if len(path) == 0 { + return "_" + } + n := []byte{} + for i := 0; i < len(path); i++ { + c := path[i] + if needsEscape(i, c) { + e := fmt.Sprintf("_%x", c) + n = append(n, []byte(e)...) + } else { + n = append(n, c) + } + } + return string(n) +} + +// pathBusUnescape is the inverse of PathBusEscape. +func pathBusUnescape(path string) string { + if path == "_" { + return "" + } + n := []byte{} + for i := 0; i < len(path); i++ { + c := path[i] + if c == '_' && i+2 < len(path) { + res, err := hex.DecodeString(path[i+1 : i+3]) + if err == nil { + n = append(n, res...) 
+ } + i += 2 + } else { + n = append(n, c) + } + } + return string(n) +} + +// Conn is a connection to systemd's dbus endpoint. +type Conn struct { + // sysconn/sysobj are only used to call dbus methods + sysconn *dbus.Conn + sysobj dbus.BusObject + + // sigconn/sigobj are only used to receive dbus signals + sigconn *dbus.Conn + sigobj dbus.BusObject + + jobListener struct { + jobs map[dbus.ObjectPath]chan<- string + sync.Mutex + } + subStateSubscriber struct { + updateCh chan<- *SubStateUpdate + errCh chan<- error + sync.Mutex + ignore map[dbus.ObjectPath]int64 + cleanIgnore int64 + } + propertiesSubscriber struct { + updateCh chan<- *PropertiesUpdate + errCh chan<- error + sync.Mutex + } +} + +// Deprecated: use NewWithContext instead. +func New() (*Conn, error) { + return NewWithContext(context.Background()) +} + +// NewWithContext establishes a connection to any available bus and authenticates. +// Callers should call Close() when done with the connection. +func NewWithContext(ctx context.Context) (*Conn, error) { + conn, err := NewSystemConnectionContext(ctx) + if err != nil && os.Geteuid() == 0 { + return NewSystemdConnectionContext(ctx) + } + return conn, err +} + +// Deprecated: use NewSystemConnectionContext instead. +func NewSystemConnection() (*Conn, error) { + return NewSystemConnectionContext(context.Background()) +} + +// NewSystemConnectionContext establishes a connection to the system bus and authenticates. +// Callers should call Close() when done with the connection. +func NewSystemConnectionContext(ctx context.Context) (*Conn, error) { + return NewConnection(func() (*dbus.Conn, error) { + return dbusAuthHelloConnection(ctx, dbus.SystemBusPrivate) + }) +} + +// Deprecated: use NewUserConnectionContext instead. +func NewUserConnection() (*Conn, error) { + return NewUserConnectionContext(context.Background()) +} + +// NewUserConnectionContext establishes a connection to the session bus and +// authenticates. 
This can be used to connect to systemd user instances. +// Callers should call Close() when done with the connection. +func NewUserConnectionContext(ctx context.Context) (*Conn, error) { + return NewConnection(func() (*dbus.Conn, error) { + return dbusAuthHelloConnection(ctx, dbus.SessionBusPrivate) + }) +} + +// Deprecated: use NewSystemdConnectionContext instead. +func NewSystemdConnection() (*Conn, error) { + return NewSystemdConnectionContext(context.Background()) +} + +// NewSystemdConnectionContext establishes a private, direct connection to systemd. +// This can be used for communicating with systemd without a dbus daemon. +// Callers should call Close() when done with the connection. +func NewSystemdConnectionContext(ctx context.Context) (*Conn, error) { + return NewConnection(func() (*dbus.Conn, error) { + // We skip Hello when talking directly to systemd. + return dbusAuthConnection(ctx, func(opts ...dbus.ConnOption) (*dbus.Conn, error) { + return dbus.Dial("unix:path=/run/systemd/private", opts...) + }) + }) +} + +// Close closes an established connection. +func (c *Conn) Close() { + c.sysconn.Close() + c.sigconn.Close() +} + +// Connected returns whether conn is connected +func (c *Conn) Connected() bool { + return c.sysconn.Connected() && c.sigconn.Connected() +} + +// NewConnection establishes a connection to a bus using a caller-supplied function. +// This allows connecting to remote buses through a user-supplied mechanism. +// The supplied function may be called multiple times, and should return independent connections. +// The returned connection must be fully initialised: the org.freedesktop.DBus.Hello call must have succeeded, +// and any authentication should be handled by the function. 
+func NewConnection(dialBus func() (*dbus.Conn, error)) (*Conn, error) { + sysconn, err := dialBus() + if err != nil { + return nil, err + } + + sigconn, err := dialBus() + if err != nil { + sysconn.Close() + return nil, err + } + + c := &Conn{ + sysconn: sysconn, + sysobj: systemdObject(sysconn), + sigconn: sigconn, + sigobj: systemdObject(sigconn), + } + + c.subStateSubscriber.ignore = make(map[dbus.ObjectPath]int64) + c.jobListener.jobs = make(map[dbus.ObjectPath]chan<- string) + + // Setup the listeners on jobs so that we can get completions + c.sigconn.BusObject().Call("org.freedesktop.DBus.AddMatch", 0, + "type='signal', interface='org.freedesktop.systemd1.Manager', member='JobRemoved'") + + c.dispatch() + return c, nil +} + +// GetManagerProperty returns the value of a property on the org.freedesktop.systemd1.Manager +// interface. The value is returned in its string representation, as defined at +// https://developer.gnome.org/glib/unstable/gvariant-text.html. +func (c *Conn) GetManagerProperty(prop string) (string, error) { + variant, err := c.sysobj.GetProperty("org.freedesktop.systemd1.Manager." 
+ prop) + if err != nil { + return "", err + } + return variant.String(), nil +} + +func dbusAuthConnection(ctx context.Context, createBus func(opts ...dbus.ConnOption) (*dbus.Conn, error)) (*dbus.Conn, error) { + conn, err := createBus(dbus.WithContext(ctx)) + if err != nil { + return nil, err + } + + // Only use EXTERNAL method, and hardcode the uid (not username) + // to avoid a username lookup (which requires a dynamically linked + // libc) + methods := []dbus.Auth{dbus.AuthExternal(strconv.Itoa(os.Getuid()))} + + err = conn.Auth(methods) + if err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +func dbusAuthHelloConnection(ctx context.Context, createBus func(opts ...dbus.ConnOption) (*dbus.Conn, error)) (*dbus.Conn, error) { + conn, err := dbusAuthConnection(ctx, createBus) + if err != nil { + return nil, err + } + + if err = conn.Hello(); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +func systemdObject(conn *dbus.Conn) dbus.BusObject { + return conn.Object("org.freedesktop.systemd1", dbus.ObjectPath("/org/freedesktop/systemd1")) +} diff --git a/vendor/github.com/coreos/go-systemd/v22/dbus/methods.go b/vendor/github.com/coreos/go-systemd/v22/dbus/methods.go new file mode 100644 index 0000000000..074148cb4d --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/v22/dbus/methods.go @@ -0,0 +1,864 @@ +// Copyright 2015, 2018 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package dbus + +import ( + "context" + "errors" + "fmt" + "path" + "strconv" + + "github.com/godbus/dbus/v5" +) + +// Who can be used to specify which process to kill in the unit via the KillUnitWithTarget API +type Who string + +const ( + // All sends the signal to all processes in the unit + All Who = "all" + // Main sends the signal to the main process of the unit + Main Who = "main" + // Control sends the signal to the control process of the unit + Control Who = "control" +) + +func (c *Conn) jobComplete(signal *dbus.Signal) { + var id uint32 + var job dbus.ObjectPath + var unit string + var result string + dbus.Store(signal.Body, &id, &job, &unit, &result) + c.jobListener.Lock() + out, ok := c.jobListener.jobs[job] + if ok { + out <- result + delete(c.jobListener.jobs, job) + } + c.jobListener.Unlock() +} + +func (c *Conn) startJob(ctx context.Context, ch chan<- string, job string, args ...interface{}) (int, error) { + if ch != nil { + c.jobListener.Lock() + defer c.jobListener.Unlock() + } + + var p dbus.ObjectPath + err := c.sysobj.CallWithContext(ctx, job, 0, args...).Store(&p) + if err != nil { + return 0, err + } + + if ch != nil { + c.jobListener.jobs[p] = ch + } + + // ignore error since 0 is fine if conversion fails + jobID, _ := strconv.Atoi(path.Base(string(p))) + + return jobID, nil +} + +// Deprecated: use StartUnitContext instead. +func (c *Conn) StartUnit(name string, mode string, ch chan<- string) (int, error) { + return c.StartUnitContext(context.Background(), name, mode, ch) +} + +// StartUnitContext enqueues a start job and depending jobs, if any (unless otherwise +// specified by the mode string). +// +// Takes the unit to activate, plus a mode string. The mode needs to be one of +// replace, fail, isolate, ignore-dependencies, ignore-requirements. If +// "replace" the call will start the unit and its dependencies, possibly +// replacing already queued jobs that conflict with this. 
If "fail" the call +// will start the unit and its dependencies, but will fail if this would change +// an already queued job. If "isolate" the call will start the unit in question +// and terminate all units that aren't dependencies of it. If +// "ignore-dependencies" it will start a unit but ignore all its dependencies. +// If "ignore-requirements" it will start a unit but only ignore the +// requirement dependencies. It is not recommended to make use of the latter +// two options. +// +// If the provided channel is non-nil, a result string will be sent to it upon +// job completion: one of done, canceled, timeout, failed, dependency, skipped. +// done indicates successful execution of a job. canceled indicates that a job +// has been canceled before it finished execution. timeout indicates that the +// job timeout was reached. failed indicates that the job failed. dependency +// indicates that a job this job has been depending on failed and the job hence +// has been removed too. skipped indicates that a job was skipped because it +// didn't apply to the units current state. +// +// If no error occurs, the ID of the underlying systemd job will be returned. There +// does exist the possibility for no error to be returned, but for the returned job +// ID to be 0. In this case, the actual underlying ID is not 0 and this datapoint +// should not be considered authoritative. +// +// If an error does occur, it will be returned to the user alongside a job ID of 0. +func (c *Conn) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { + return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StartUnit", name, mode) +} + +// Deprecated: use StopUnitContext instead. +func (c *Conn) StopUnit(name string, mode string, ch chan<- string) (int, error) { + return c.StopUnitContext(context.Background(), name, mode, ch) +} + +// StopUnitContext is similar to StartUnitContext, but stops the specified unit +// rather than starting it. 
+func (c *Conn) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { + return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StopUnit", name, mode) +} + +// Deprecated: use ReloadUnitContext instead. +func (c *Conn) ReloadUnit(name string, mode string, ch chan<- string) (int, error) { + return c.ReloadUnitContext(context.Background(), name, mode, ch) +} + +// ReloadUnitContext reloads a unit. Reloading is done only if the unit +// is already running, and fails otherwise. +func (c *Conn) ReloadUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { + return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadUnit", name, mode) +} + +// Deprecated: use RestartUnitContext instead. +func (c *Conn) RestartUnit(name string, mode string, ch chan<- string) (int, error) { + return c.RestartUnitContext(context.Background(), name, mode, ch) +} + +// RestartUnitContext restarts a service. If a service is restarted that isn't +// running it will be started. +func (c *Conn) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { + return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.RestartUnit", name, mode) +} + +// Deprecated: use TryRestartUnitContext instead. +func (c *Conn) TryRestartUnit(name string, mode string, ch chan<- string) (int, error) { + return c.TryRestartUnitContext(context.Background(), name, mode, ch) +} + +// TryRestartUnitContext is like RestartUnitContext, except that a service that +// isn't running is not affected by the restart. +func (c *Conn) TryRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { + return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.TryRestartUnit", name, mode) +} + +// Deprecated: use ReloadOrRestartUnitContext instead. 
+func (c *Conn) ReloadOrRestartUnit(name string, mode string, ch chan<- string) (int, error) { + return c.ReloadOrRestartUnitContext(context.Background(), name, mode, ch) +} + +// ReloadOrRestartUnitContext attempts a reload if the unit supports it and use +// a restart otherwise. +func (c *Conn) ReloadOrRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { + return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadOrRestartUnit", name, mode) +} + +// Deprecated: use ReloadOrTryRestartUnitContext instead. +func (c *Conn) ReloadOrTryRestartUnit(name string, mode string, ch chan<- string) (int, error) { + return c.ReloadOrTryRestartUnitContext(context.Background(), name, mode, ch) +} + +// ReloadOrTryRestartUnitContext attempts a reload if the unit supports it, +// and use a "Try" flavored restart otherwise. +func (c *Conn) ReloadOrTryRestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) { + return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.ReloadOrTryRestartUnit", name, mode) +} + +// Deprecated: use StartTransientUnitContext instead. +func (c *Conn) StartTransientUnit(name string, mode string, properties []Property, ch chan<- string) (int, error) { + return c.StartTransientUnitContext(context.Background(), name, mode, properties, ch) +} + +// StartTransientUnitContext may be used to create and start a transient unit, which +// will be released as soon as it is not running or referenced anymore or the +// system is rebooted. name is the unit name including suffix, and must be +// unique. mode is the same as in StartUnitContext, properties contains properties +// of the unit. 
+func (c *Conn) StartTransientUnitContext(ctx context.Context, name string, mode string, properties []Property, ch chan<- string) (int, error) { + return c.startJob(ctx, ch, "org.freedesktop.systemd1.Manager.StartTransientUnit", name, mode, properties, make([]PropertyCollection, 0)) +} + +// Deprecated: use KillUnitContext instead. +func (c *Conn) KillUnit(name string, signal int32) { + c.KillUnitContext(context.Background(), name, signal) +} + +// KillUnitContext takes the unit name and a UNIX signal number to send. +// All of the unit's processes are killed. +func (c *Conn) KillUnitContext(ctx context.Context, name string, signal int32) { + c.KillUnitWithTarget(ctx, name, All, signal) +} + +// KillUnitWithTarget is like KillUnitContext, but allows you to specify which +// process in the unit to send the signal to. +func (c *Conn) KillUnitWithTarget(ctx context.Context, name string, target Who, signal int32) error { + return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.KillUnit", 0, name, string(target), signal).Store() +} + +// Deprecated: use ResetFailedUnitContext instead. +func (c *Conn) ResetFailedUnit(name string) error { + return c.ResetFailedUnitContext(context.Background(), name) +} + +// ResetFailedUnitContext resets the "failed" state of a specific unit. +func (c *Conn) ResetFailedUnitContext(ctx context.Context, name string) error { + return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ResetFailedUnit", 0, name).Store() +} + +// Deprecated: use SystemStateContext instead. +func (c *Conn) SystemState() (*Property, error) { + return c.SystemStateContext(context.Background()) +} + +// SystemStateContext returns the systemd state. Equivalent to +// systemctl is-system-running. 
+func (c *Conn) SystemStateContext(ctx context.Context) (*Property, error) { + var err error + var prop dbus.Variant + + obj := c.sysconn.Object("org.freedesktop.systemd1", "/org/freedesktop/systemd1") + err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, "org.freedesktop.systemd1.Manager", "SystemState").Store(&prop) + if err != nil { + return nil, err + } + + return &Property{Name: "SystemState", Value: prop}, nil +} + +// getProperties takes the unit path and returns all of its dbus object properties, for the given dbus interface. +func (c *Conn) getProperties(ctx context.Context, path dbus.ObjectPath, dbusInterface string) (map[string]interface{}, error) { + var err error + var props map[string]dbus.Variant + + if !path.IsValid() { + return nil, fmt.Errorf("invalid unit name: %v", path) + } + + obj := c.sysconn.Object("org.freedesktop.systemd1", path) + err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.GetAll", 0, dbusInterface).Store(&props) + if err != nil { + return nil, err + } + + out := make(map[string]interface{}, len(props)) + for k, v := range props { + out[k] = v.Value() + } + + return out, nil +} + +// Deprecated: use GetUnitPropertiesContext instead. +func (c *Conn) GetUnitProperties(unit string) (map[string]interface{}, error) { + return c.GetUnitPropertiesContext(context.Background(), unit) +} + +// GetUnitPropertiesContext takes the (unescaped) unit name and returns all of +// its dbus object properties. +func (c *Conn) GetUnitPropertiesContext(ctx context.Context, unit string) (map[string]interface{}, error) { + path := unitPath(unit) + return c.getProperties(ctx, path, "org.freedesktop.systemd1.Unit") +} + +// Deprecated: use GetUnitPathPropertiesContext instead. 
+func (c *Conn) GetUnitPathProperties(path dbus.ObjectPath) (map[string]interface{}, error) { + return c.GetUnitPathPropertiesContext(context.Background(), path) +} + +// GetUnitPathPropertiesContext takes the (escaped) unit path and returns all +// of its dbus object properties. +func (c *Conn) GetUnitPathPropertiesContext(ctx context.Context, path dbus.ObjectPath) (map[string]interface{}, error) { + return c.getProperties(ctx, path, "org.freedesktop.systemd1.Unit") +} + +// Deprecated: use GetAllPropertiesContext instead. +func (c *Conn) GetAllProperties(unit string) (map[string]interface{}, error) { + return c.GetAllPropertiesContext(context.Background(), unit) +} + +// GetAllPropertiesContext takes the (unescaped) unit name and returns all of +// its dbus object properties. +func (c *Conn) GetAllPropertiesContext(ctx context.Context, unit string) (map[string]interface{}, error) { + path := unitPath(unit) + return c.getProperties(ctx, path, "") +} + +func (c *Conn) getProperty(ctx context.Context, unit string, dbusInterface string, propertyName string) (*Property, error) { + var err error + var prop dbus.Variant + + path := unitPath(unit) + if !path.IsValid() { + return nil, errors.New("invalid unit name: " + unit) + } + + obj := c.sysconn.Object("org.freedesktop.systemd1", path) + err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, dbusInterface, propertyName).Store(&prop) + if err != nil { + return nil, err + } + + return &Property{Name: propertyName, Value: prop}, nil +} + +// Deprecated: use GetUnitPropertyContext instead. +func (c *Conn) GetUnitProperty(unit string, propertyName string) (*Property, error) { + return c.GetUnitPropertyContext(context.Background(), unit, propertyName) +} + +// GetUnitPropertyContext takes an (unescaped) unit name, and a property name, +// and returns the property value. 
+func (c *Conn) GetUnitPropertyContext(ctx context.Context, unit string, propertyName string) (*Property, error) { + return c.getProperty(ctx, unit, "org.freedesktop.systemd1.Unit", propertyName) +} + +// Deprecated: use GetServicePropertyContext instead. +func (c *Conn) GetServiceProperty(service string, propertyName string) (*Property, error) { + return c.GetServicePropertyContext(context.Background(), service, propertyName) +} + +// GetServiceProperty returns property for given service name and property name. +func (c *Conn) GetServicePropertyContext(ctx context.Context, service string, propertyName string) (*Property, error) { + return c.getProperty(ctx, service, "org.freedesktop.systemd1.Service", propertyName) +} + +// Deprecated: use GetUnitTypePropertiesContext instead. +func (c *Conn) GetUnitTypeProperties(unit string, unitType string) (map[string]interface{}, error) { + return c.GetUnitTypePropertiesContext(context.Background(), unit, unitType) +} + +// GetUnitTypePropertiesContext returns the extra properties for a unit, specific to the unit type. +// Valid values for unitType: Service, Socket, Target, Device, Mount, Automount, Snapshot, Timer, Swap, Path, Slice, Scope. +// Returns "dbus.Error: Unknown interface" error if the unitType is not the correct type of the unit. +func (c *Conn) GetUnitTypePropertiesContext(ctx context.Context, unit string, unitType string) (map[string]interface{}, error) { + path := unitPath(unit) + return c.getProperties(ctx, path, "org.freedesktop.systemd1."+unitType) +} + +// Deprecated: use SetUnitPropertiesContext instead. +func (c *Conn) SetUnitProperties(name string, runtime bool, properties ...Property) error { + return c.SetUnitPropertiesContext(context.Background(), name, runtime, properties...) +} + +// SetUnitPropertiesContext may be used to modify certain unit properties at runtime. +// Not all properties may be changed at runtime, but many resource management +// settings (primarily those in systemd.cgroup(5)) may. 
The changes are applied +// instantly, and stored on disk for future boots, unless runtime is true, in which +// case the settings only apply until the next reboot. name is the name of the unit +// to modify. properties are the settings to set, encoded as an array of property +// name and value pairs. +func (c *Conn) SetUnitPropertiesContext(ctx context.Context, name string, runtime bool, properties ...Property) error { + return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.SetUnitProperties", 0, name, runtime, properties).Store() +} + +// Deprecated: use GetUnitTypePropertyContext instead. +func (c *Conn) GetUnitTypeProperty(unit string, unitType string, propertyName string) (*Property, error) { + return c.GetUnitTypePropertyContext(context.Background(), unit, unitType, propertyName) +} + +// GetUnitTypePropertyContext takes a property name, a unit name, and a unit type, +// and returns a property value. For valid values of unitType, see GetUnitTypePropertiesContext. +func (c *Conn) GetUnitTypePropertyContext(ctx context.Context, unit string, unitType string, propertyName string) (*Property, error) { + return c.getProperty(ctx, unit, "org.freedesktop.systemd1."+unitType, propertyName) +} + +type UnitStatus struct { + Name string // The primary unit name as string + Description string // The human readable description string + LoadState string // The load state (i.e. whether the unit file has been loaded successfully) + ActiveState string // The active state (i.e. whether the unit is currently started or not) + SubState string // The sub state (a more fine-grained version of the active state that is specific to the unit type, which the active state is not) + Followed string // A unit that is being followed in its state by this unit, if there is any, otherwise the empty string. 
+ Path dbus.ObjectPath // The unit object path + JobId uint32 // If there is a job queued for the job unit the numeric job id, 0 otherwise + JobType string // The job type as string + JobPath dbus.ObjectPath // The job object path +} + +type storeFunc func(retvalues ...interface{}) error + +func (c *Conn) listUnitsInternal(f storeFunc) ([]UnitStatus, error) { + result := make([][]interface{}, 0) + err := f(&result) + if err != nil { + return nil, err + } + + resultInterface := make([]interface{}, len(result)) + for i := range result { + resultInterface[i] = result[i] + } + + status := make([]UnitStatus, len(result)) + statusInterface := make([]interface{}, len(status)) + for i := range status { + statusInterface[i] = &status[i] + } + + err = dbus.Store(resultInterface, statusInterface...) + if err != nil { + return nil, err + } + + return status, nil +} + +// GetUnitByPID returns the unit object path of the unit a process ID +// belongs to. It takes a UNIX PID and returns the object path. The PID must +// refer to an existing system process +func (c *Conn) GetUnitByPID(ctx context.Context, pid uint32) (dbus.ObjectPath, error) { + var result dbus.ObjectPath + + err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.GetUnitByPID", 0, pid).Store(&result) + + return result, err +} + +// GetUnitNameByPID returns the name of the unit a process ID belongs to. It +// takes a UNIX PID and returns the object path. The PID must refer to an +// existing system process +func (c *Conn) GetUnitNameByPID(ctx context.Context, pid uint32) (string, error) { + path, err := c.GetUnitByPID(ctx, pid) + if err != nil { + return "", err + } + + return unitName(path), nil +} + +// Deprecated: use ListUnitsContext instead. +func (c *Conn) ListUnits() ([]UnitStatus, error) { + return c.ListUnitsContext(context.Background()) +} + +// ListUnitsContext returns an array with all currently loaded units. 
Note that +// units may be known by multiple names at the same time, and hence there might +// be more unit names loaded than actual units behind them. +// Also note that a unit is only loaded if it is active and/or enabled. +// Units that are both disabled and inactive will thus not be returned. +func (c *Conn) ListUnitsContext(ctx context.Context) ([]UnitStatus, error) { + return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnits", 0).Store) +} + +// Deprecated: use ListUnitsFilteredContext instead. +func (c *Conn) ListUnitsFiltered(states []string) ([]UnitStatus, error) { + return c.ListUnitsFilteredContext(context.Background(), states) +} + +// ListUnitsFilteredContext returns an array with units filtered by state. +// It takes a list of units' statuses to filter. +func (c *Conn) ListUnitsFilteredContext(ctx context.Context, states []string) ([]UnitStatus, error) { + return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsFiltered", 0, states).Store) +} + +// Deprecated: use ListUnitsByPatternsContext instead. +func (c *Conn) ListUnitsByPatterns(states []string, patterns []string) ([]UnitStatus, error) { + return c.ListUnitsByPatternsContext(context.Background(), states, patterns) +} + +// ListUnitsByPatternsContext returns an array with units. +// It takes a list of units' statuses and names to filter. +// Note that units may be known by multiple names at the same time, +// and hence there might be more unit names loaded than actual units behind them. +func (c *Conn) ListUnitsByPatternsContext(ctx context.Context, states []string, patterns []string) ([]UnitStatus, error) { + return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsByPatterns", 0, states, patterns).Store) +} + +// Deprecated: use ListUnitsByNamesContext instead. 
+func (c *Conn) ListUnitsByNames(units []string) ([]UnitStatus, error) { + return c.ListUnitsByNamesContext(context.Background(), units) +} + +// ListUnitsByNamesContext returns an array with units. It takes a list of units' +// names and returns an UnitStatus array. Comparing to ListUnitsByPatternsContext +// method, this method returns statuses even for inactive or non-existing +// units. Input array should contain exact unit names, but not patterns. +// +// Requires systemd v230 or higher. +func (c *Conn) ListUnitsByNamesContext(ctx context.Context, units []string) ([]UnitStatus, error) { + return c.listUnitsInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitsByNames", 0, units).Store) +} + +type UnitFile struct { + Path string + Type string +} + +func (c *Conn) listUnitFilesInternal(f storeFunc) ([]UnitFile, error) { + result := make([][]interface{}, 0) + err := f(&result) + if err != nil { + return nil, err + } + + resultInterface := make([]interface{}, len(result)) + for i := range result { + resultInterface[i] = result[i] + } + + files := make([]UnitFile, len(result)) + fileInterface := make([]interface{}, len(files)) + for i := range files { + fileInterface[i] = &files[i] + } + + err = dbus.Store(resultInterface, fileInterface...) + if err != nil { + return nil, err + } + + return files, nil +} + +// Deprecated: use ListUnitFilesContext instead. +func (c *Conn) ListUnitFiles() ([]UnitFile, error) { + return c.ListUnitFilesContext(context.Background()) +} + +// ListUnitFiles returns an array of all available units on disk. +func (c *Conn) ListUnitFilesContext(ctx context.Context) ([]UnitFile, error) { + return c.listUnitFilesInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitFiles", 0).Store) +} + +// Deprecated: use ListUnitFilesByPatternsContext instead. 
+func (c *Conn) ListUnitFilesByPatterns(states []string, patterns []string) ([]UnitFile, error) { + return c.ListUnitFilesByPatternsContext(context.Background(), states, patterns) +} + +// ListUnitFilesByPatternsContext returns an array of all available units on disk matched the patterns. +func (c *Conn) ListUnitFilesByPatternsContext(ctx context.Context, states []string, patterns []string) ([]UnitFile, error) { + return c.listUnitFilesInternal(c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListUnitFilesByPatterns", 0, states, patterns).Store) +} + +type LinkUnitFileChange EnableUnitFileChange + +// Deprecated: use LinkUnitFilesContext instead. +func (c *Conn) LinkUnitFiles(files []string, runtime bool, force bool) ([]LinkUnitFileChange, error) { + return c.LinkUnitFilesContext(context.Background(), files, runtime, force) +} + +// LinkUnitFilesContext links unit files (that are located outside of the +// usual unit search paths) into the unit search path. +// +// It takes a list of absolute paths to unit files to link and two +// booleans. +// +// The first boolean controls whether the unit shall be +// enabled for runtime only (true, /run), or persistently (false, +// /etc). +// +// The second controls whether symlinks pointing to other units shall +// be replaced if necessary. +// +// This call returns a list of the changes made. The list consists of +// structures with three strings: the type of the change (one of symlink +// or unlink), the file name of the symlink and the destination of the +// symlink. 
+func (c *Conn) LinkUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) ([]LinkUnitFileChange, error) { + result := make([][]interface{}, 0) + err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.LinkUnitFiles", 0, files, runtime, force).Store(&result) + if err != nil { + return nil, err + } + + resultInterface := make([]interface{}, len(result)) + for i := range result { + resultInterface[i] = result[i] + } + + changes := make([]LinkUnitFileChange, len(result)) + changesInterface := make([]interface{}, len(changes)) + for i := range changes { + changesInterface[i] = &changes[i] + } + + err = dbus.Store(resultInterface, changesInterface...) + if err != nil { + return nil, err + } + + return changes, nil +} + +// Deprecated: use EnableUnitFilesContext instead. +func (c *Conn) EnableUnitFiles(files []string, runtime bool, force bool) (bool, []EnableUnitFileChange, error) { + return c.EnableUnitFilesContext(context.Background(), files, runtime, force) +} + +// EnableUnitFilesContext may be used to enable one or more units in the system +// (by creating symlinks to them in /etc or /run). +// +// It takes a list of unit files to enable (either just file names or full +// absolute paths if the unit files are residing outside the usual unit +// search paths), and two booleans: the first controls whether the unit shall +// be enabled for runtime only (true, /run), or persistently (false, /etc). +// The second one controls whether symlinks pointing to other units shall +// be replaced if necessary. +// +// This call returns one boolean and an array with the changes made. The +// boolean signals whether the unit files contained any enablement +// information (i.e. an [Install]) section. The changes list consists of +// structures with three strings: the type of the change (one of symlink +// or unlink), the file name of the symlink and the destination of the +// symlink. 
+func (c *Conn) EnableUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) (bool, []EnableUnitFileChange, error) { + var carries_install_info bool + + result := make([][]interface{}, 0) + err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.EnableUnitFiles", 0, files, runtime, force).Store(&carries_install_info, &result) + if err != nil { + return false, nil, err + } + + resultInterface := make([]interface{}, len(result)) + for i := range result { + resultInterface[i] = result[i] + } + + changes := make([]EnableUnitFileChange, len(result)) + changesInterface := make([]interface{}, len(changes)) + for i := range changes { + changesInterface[i] = &changes[i] + } + + err = dbus.Store(resultInterface, changesInterface...) + if err != nil { + return false, nil, err + } + + return carries_install_info, changes, nil +} + +type EnableUnitFileChange struct { + Type string // Type of the change (one of symlink or unlink) + Filename string // File name of the symlink + Destination string // Destination of the symlink +} + +// Deprecated: use DisableUnitFilesContext instead. +func (c *Conn) DisableUnitFiles(files []string, runtime bool) ([]DisableUnitFileChange, error) { + return c.DisableUnitFilesContext(context.Background(), files, runtime) +} + +// DisableUnitFilesContext may be used to disable one or more units in the +// system (by removing symlinks to them from /etc or /run). +// +// It takes a list of unit files to disable (either just file names or full +// absolute paths if the unit files are residing outside the usual unit +// search paths), and one boolean: whether the unit was enabled for runtime +// only (true, /run), or persistently (false, /etc). +// +// This call returns an array with the changes made. The changes list +// consists of structures with three strings: the type of the change (one of +// symlink or unlink), the file name of the symlink and the destination of the +// symlink. 
+func (c *Conn) DisableUnitFilesContext(ctx context.Context, files []string, runtime bool) ([]DisableUnitFileChange, error) { + result := make([][]interface{}, 0) + err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.DisableUnitFiles", 0, files, runtime).Store(&result) + if err != nil { + return nil, err + } + + resultInterface := make([]interface{}, len(result)) + for i := range result { + resultInterface[i] = result[i] + } + + changes := make([]DisableUnitFileChange, len(result)) + changesInterface := make([]interface{}, len(changes)) + for i := range changes { + changesInterface[i] = &changes[i] + } + + err = dbus.Store(resultInterface, changesInterface...) + if err != nil { + return nil, err + } + + return changes, nil +} + +type DisableUnitFileChange struct { + Type string // Type of the change (one of symlink or unlink) + Filename string // File name of the symlink + Destination string // Destination of the symlink +} + +// Deprecated: use MaskUnitFilesContext instead. +func (c *Conn) MaskUnitFiles(files []string, runtime bool, force bool) ([]MaskUnitFileChange, error) { + return c.MaskUnitFilesContext(context.Background(), files, runtime, force) +} + +// MaskUnitFilesContext masks one or more units in the system. +// +// The files argument contains a list of units to mask (either just file names +// or full absolute paths if the unit files are residing outside the usual unit +// search paths). +// +// The runtime argument is used to specify whether the unit was enabled for +// runtime only (true, /run/systemd/..), or persistently (false, +// /etc/systemd/..). 
+func (c *Conn) MaskUnitFilesContext(ctx context.Context, files []string, runtime bool, force bool) ([]MaskUnitFileChange, error) { + result := make([][]interface{}, 0) + err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.MaskUnitFiles", 0, files, runtime, force).Store(&result) + if err != nil { + return nil, err + } + + resultInterface := make([]interface{}, len(result)) + for i := range result { + resultInterface[i] = result[i] + } + + changes := make([]MaskUnitFileChange, len(result)) + changesInterface := make([]interface{}, len(changes)) + for i := range changes { + changesInterface[i] = &changes[i] + } + + err = dbus.Store(resultInterface, changesInterface...) + if err != nil { + return nil, err + } + + return changes, nil +} + +type MaskUnitFileChange struct { + Type string // Type of the change (one of symlink or unlink) + Filename string // File name of the symlink + Destination string // Destination of the symlink +} + +// Deprecated: use UnmaskUnitFilesContext instead. +func (c *Conn) UnmaskUnitFiles(files []string, runtime bool) ([]UnmaskUnitFileChange, error) { + return c.UnmaskUnitFilesContext(context.Background(), files, runtime) +} + +// UnmaskUnitFilesContext unmasks one or more units in the system. +// +// It takes the list of unit files to mask (either just file names or full +// absolute paths if the unit files are residing outside the usual unit search +// paths), and a boolean runtime flag to specify whether the unit was enabled +// for runtime only (true, /run/systemd/..), or persistently (false, +// /etc/systemd/..). 
+func (c *Conn) UnmaskUnitFilesContext(ctx context.Context, files []string, runtime bool) ([]UnmaskUnitFileChange, error) { + result := make([][]interface{}, 0) + err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.UnmaskUnitFiles", 0, files, runtime).Store(&result) + if err != nil { + return nil, err + } + + resultInterface := make([]interface{}, len(result)) + for i := range result { + resultInterface[i] = result[i] + } + + changes := make([]UnmaskUnitFileChange, len(result)) + changesInterface := make([]interface{}, len(changes)) + for i := range changes { + changesInterface[i] = &changes[i] + } + + err = dbus.Store(resultInterface, changesInterface...) + if err != nil { + return nil, err + } + + return changes, nil +} + +type UnmaskUnitFileChange struct { + Type string // Type of the change (one of symlink or unlink) + Filename string // File name of the symlink + Destination string // Destination of the symlink +} + +// Deprecated: use ReloadContext instead. +func (c *Conn) Reload() error { + return c.ReloadContext(context.Background()) +} + +// ReloadContext instructs systemd to scan for and reload unit files. This is +// an equivalent to systemctl daemon-reload. +func (c *Conn) ReloadContext(ctx context.Context) error { + return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.Reload", 0).Store() +} + +func unitPath(name string) dbus.ObjectPath { + return dbus.ObjectPath("/org/freedesktop/systemd1/unit/" + PathBusEscape(name)) +} + +// unitName returns the unescaped base element of the supplied escaped path. +func unitName(dpath dbus.ObjectPath) string { + return pathBusUnescape(path.Base(string(dpath))) +} + +// JobStatus holds a currently queued job definition. 
+type JobStatus struct { + Id uint32 // The numeric job id + Unit string // The primary unit name for this job + JobType string // The job type as string + Status string // The job state as string + JobPath dbus.ObjectPath // The job object path + UnitPath dbus.ObjectPath // The unit object path +} + +// Deprecated: use ListJobsContext instead. +func (c *Conn) ListJobs() ([]JobStatus, error) { + return c.ListJobsContext(context.Background()) +} + +// ListJobsContext returns an array with all currently queued jobs. +func (c *Conn) ListJobsContext(ctx context.Context) ([]JobStatus, error) { + return c.listJobsInternal(ctx) +} + +func (c *Conn) listJobsInternal(ctx context.Context) ([]JobStatus, error) { + result := make([][]interface{}, 0) + if err := c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ListJobs", 0).Store(&result); err != nil { + return nil, err + } + + resultInterface := make([]interface{}, len(result)) + for i := range result { + resultInterface[i] = result[i] + } + + status := make([]JobStatus, len(result)) + statusInterface := make([]interface{}, len(status)) + for i := range status { + statusInterface[i] = &status[i] + } + + if err := dbus.Store(resultInterface, statusInterface...); err != nil { + return nil, err + } + + return status, nil +} + +// Freeze the cgroup associated with the unit. +// Note that FreezeUnit and ThawUnit are only supported on systems running with cgroup v2. +func (c *Conn) FreezeUnit(ctx context.Context, unit string) error { + return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.FreezeUnit", 0, unit).Store() +} + +// Unfreeze the cgroup associated with the unit. 
+func (c *Conn) ThawUnit(ctx context.Context, unit string) error { + return c.sysobj.CallWithContext(ctx, "org.freedesktop.systemd1.Manager.ThawUnit", 0, unit).Store() +} diff --git a/vendor/github.com/coreos/go-systemd/v22/dbus/properties.go b/vendor/github.com/coreos/go-systemd/v22/dbus/properties.go new file mode 100644 index 0000000000..fb42b62733 --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/v22/dbus/properties.go @@ -0,0 +1,237 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dbus + +import ( + "github.com/godbus/dbus/v5" +) + +// From the systemd docs: +// +// The properties array of StartTransientUnit() may take many of the settings +// that may also be configured in unit files. Not all parameters are currently +// accepted though, but we plan to cover more properties with future release. +// Currently you may set the Description, Slice and all dependency types of +// units, as well as RemainAfterExit, ExecStart for service units, +// TimeoutStopUSec and PIDs for scope units, and CPUAccounting, CPUShares, +// BlockIOAccounting, BlockIOWeight, BlockIOReadBandwidth, +// BlockIOWriteBandwidth, BlockIODeviceWeight, MemoryAccounting, MemoryLimit, +// DevicePolicy, DeviceAllow for services/scopes/slices. These fields map +// directly to their counterparts in unit files and as normal D-Bus object +// properties. 
The exception here is the PIDs field of scope units which is +// used for construction of the scope only and specifies the initial PIDs to +// add to the scope object. + +type Property struct { + Name string + Value dbus.Variant +} + +type PropertyCollection struct { + Name string + Properties []Property +} + +type execStart struct { + Path string // the binary path to execute + Args []string // an array with all arguments to pass to the executed command, starting with argument 0 + UncleanIsFailure bool // a boolean whether it should be considered a failure if the process exits uncleanly +} + +// PropExecStart sets the ExecStart service property. The first argument is a +// slice with the binary path to execute followed by the arguments to pass to +// the executed command. See +// http://www.freedesktop.org/software/systemd/man/systemd.service.html#ExecStart= +func PropExecStart(command []string, uncleanIsFailure bool) Property { + execStarts := []execStart{ + { + Path: command[0], + Args: command, + UncleanIsFailure: uncleanIsFailure, + }, + } + + return Property{ + Name: "ExecStart", + Value: dbus.MakeVariant(execStarts), + } +} + +// PropRemainAfterExit sets the RemainAfterExit service property. See +// http://www.freedesktop.org/software/systemd/man/systemd.service.html#RemainAfterExit= +func PropRemainAfterExit(b bool) Property { + return Property{ + Name: "RemainAfterExit", + Value: dbus.MakeVariant(b), + } +} + +// PropType sets the Type service property. See +// http://www.freedesktop.org/software/systemd/man/systemd.service.html#Type= +func PropType(t string) Property { + return Property{ + Name: "Type", + Value: dbus.MakeVariant(t), + } +} + +// PropDescription sets the Description unit property. 
See +// http://www.freedesktop.org/software/systemd/man/systemd.unit#Description= +func PropDescription(desc string) Property { + return Property{ + Name: "Description", + Value: dbus.MakeVariant(desc), + } +} + +func propDependency(name string, units []string) Property { + return Property{ + Name: name, + Value: dbus.MakeVariant(units), + } +} + +// PropRequires sets the Requires unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#Requires= +func PropRequires(units ...string) Property { + return propDependency("Requires", units) +} + +// PropRequiresOverridable sets the RequiresOverridable unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#RequiresOverridable= +func PropRequiresOverridable(units ...string) Property { + return propDependency("RequiresOverridable", units) +} + +// PropRequisite sets the Requisite unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#Requisite= +func PropRequisite(units ...string) Property { + return propDependency("Requisite", units) +} + +// PropRequisiteOverridable sets the RequisiteOverridable unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#RequisiteOverridable= +func PropRequisiteOverridable(units ...string) Property { + return propDependency("RequisiteOverridable", units) +} + +// PropWants sets the Wants unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#Wants= +func PropWants(units ...string) Property { + return propDependency("Wants", units) +} + +// PropBindsTo sets the BindsTo unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#BindsTo= +func PropBindsTo(units ...string) Property { + return propDependency("BindsTo", units) +} + +// PropRequiredBy sets the RequiredBy unit property. 
See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#RequiredBy= +func PropRequiredBy(units ...string) Property { + return propDependency("RequiredBy", units) +} + +// PropRequiredByOverridable sets the RequiredByOverridable unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#RequiredByOverridable= +func PropRequiredByOverridable(units ...string) Property { + return propDependency("RequiredByOverridable", units) +} + +// PropWantedBy sets the WantedBy unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#WantedBy= +func PropWantedBy(units ...string) Property { + return propDependency("WantedBy", units) +} + +// PropBoundBy sets the BoundBy unit property. See +// http://www.freedesktop.org/software/systemd/main/systemd.unit.html#BoundBy= +func PropBoundBy(units ...string) Property { + return propDependency("BoundBy", units) +} + +// PropConflicts sets the Conflicts unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#Conflicts= +func PropConflicts(units ...string) Property { + return propDependency("Conflicts", units) +} + +// PropConflictedBy sets the ConflictedBy unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#ConflictedBy= +func PropConflictedBy(units ...string) Property { + return propDependency("ConflictedBy", units) +} + +// PropBefore sets the Before unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#Before= +func PropBefore(units ...string) Property { + return propDependency("Before", units) +} + +// PropAfter sets the After unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#After= +func PropAfter(units ...string) Property { + return propDependency("After", units) +} + +// PropOnFailure sets the OnFailure unit property. 
See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#OnFailure= +func PropOnFailure(units ...string) Property { + return propDependency("OnFailure", units) +} + +// PropTriggers sets the Triggers unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#Triggers= +func PropTriggers(units ...string) Property { + return propDependency("Triggers", units) +} + +// PropTriggeredBy sets the TriggeredBy unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#TriggeredBy= +func PropTriggeredBy(units ...string) Property { + return propDependency("TriggeredBy", units) +} + +// PropPropagatesReloadTo sets the PropagatesReloadTo unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#PropagatesReloadTo= +func PropPropagatesReloadTo(units ...string) Property { + return propDependency("PropagatesReloadTo", units) +} + +// PropRequiresMountsFor sets the RequiresMountsFor unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.unit.html#RequiresMountsFor= +func PropRequiresMountsFor(units ...string) Property { + return propDependency("RequiresMountsFor", units) +} + +// PropSlice sets the Slice unit property. See +// http://www.freedesktop.org/software/systemd/man/systemd.resource-control.html#Slice= +func PropSlice(slice string) Property { + return Property{ + Name: "Slice", + Value: dbus.MakeVariant(slice), + } +} + +// PropPids sets the PIDs field of scope units used in the initial construction +// of the scope only and specifies the initial PIDs to add to the scope object. 
+// See https://www.freedesktop.org/wiki/Software/systemd/ControlGroupInterface/#properties +func PropPids(pids ...uint32) Property { + return Property{ + Name: "PIDs", + Value: dbus.MakeVariant(pids), + } +} diff --git a/vendor/github.com/coreos/go-systemd/v22/dbus/set.go b/vendor/github.com/coreos/go-systemd/v22/dbus/set.go new file mode 100644 index 0000000000..17c5d48565 --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/v22/dbus/set.go @@ -0,0 +1,47 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dbus + +type set struct { + data map[string]bool +} + +func (s *set) Add(value string) { + s.data[value] = true +} + +func (s *set) Remove(value string) { + delete(s.data, value) +} + +func (s *set) Contains(value string) (exists bool) { + _, exists = s.data[value] + return +} + +func (s *set) Length() int { + return len(s.data) +} + +func (s *set) Values() (values []string) { + for val := range s.data { + values = append(values, val) + } + return +} + +func newSet() *set { + return &set{make(map[string]bool)} +} diff --git a/vendor/github.com/coreos/go-systemd/v22/dbus/subscription.go b/vendor/github.com/coreos/go-systemd/v22/dbus/subscription.go new file mode 100644 index 0000000000..7e370fea21 --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/v22/dbus/subscription.go @@ -0,0 +1,333 @@ +// Copyright 2015 CoreOS, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dbus + +import ( + "errors" + "log" + "time" + + "github.com/godbus/dbus/v5" +) + +const ( + cleanIgnoreInterval = int64(10 * time.Second) + ignoreInterval = int64(30 * time.Millisecond) +) + +// Subscribe sets up this connection to subscribe to all systemd dbus events. +// This is required before calling SubscribeUnits. When the connection closes +// systemd will automatically stop sending signals so there is no need to +// explicitly call Unsubscribe(). +func (c *Conn) Subscribe() error { + c.sigconn.BusObject().Call("org.freedesktop.DBus.AddMatch", 0, + "type='signal',interface='org.freedesktop.systemd1.Manager',member='UnitNew'") + c.sigconn.BusObject().Call("org.freedesktop.DBus.AddMatch", 0, + "type='signal',interface='org.freedesktop.DBus.Properties',member='PropertiesChanged'") + + return c.sigobj.Call("org.freedesktop.systemd1.Manager.Subscribe", 0).Store() +} + +// Unsubscribe this connection from systemd dbus events. 
+func (c *Conn) Unsubscribe() error { + return c.sigobj.Call("org.freedesktop.systemd1.Manager.Unsubscribe", 0).Store() +} + +func (c *Conn) dispatch() { + ch := make(chan *dbus.Signal, signalBuffer) + + c.sigconn.Signal(ch) + + go func() { + for { + signal, ok := <-ch + if !ok { + return + } + + if signal.Name == "org.freedesktop.systemd1.Manager.JobRemoved" { + c.jobComplete(signal) + } + + if c.subStateSubscriber.updateCh == nil && + c.propertiesSubscriber.updateCh == nil { + continue + } + + var unitPath dbus.ObjectPath + switch signal.Name { + case "org.freedesktop.systemd1.Manager.JobRemoved": + unitName := signal.Body[2].(string) + c.sysobj.Call("org.freedesktop.systemd1.Manager.GetUnit", 0, unitName).Store(&unitPath) + case "org.freedesktop.systemd1.Manager.UnitNew": + unitPath = signal.Body[1].(dbus.ObjectPath) + case "org.freedesktop.DBus.Properties.PropertiesChanged": + if signal.Body[0].(string) == "org.freedesktop.systemd1.Unit" { + unitPath = signal.Path + + if len(signal.Body) >= 2 { + if changed, ok := signal.Body[1].(map[string]dbus.Variant); ok { + c.sendPropertiesUpdate(unitPath, changed) + } + } + } + } + + if unitPath == dbus.ObjectPath("") { + continue + } + + c.sendSubStateUpdate(unitPath) + } + }() +} + +// SubscribeUnits returns two unbuffered channels which will receive all changed units every +// interval. Deleted units are sent as nil. +func (c *Conn) SubscribeUnits(interval time.Duration) (<-chan map[string]*UnitStatus, <-chan error) { + return c.SubscribeUnitsCustom(interval, 0, func(u1, u2 *UnitStatus) bool { return *u1 != *u2 }, nil) +} + +// SubscribeUnitsCustom is like SubscribeUnits but lets you specify the buffer +// size of the channels, the comparison function for detecting changes and a filter +// function for cutting down on the noise that your channel receives. 
+func (c *Conn) SubscribeUnitsCustom(interval time.Duration, buffer int, isChanged func(*UnitStatus, *UnitStatus) bool, filterUnit func(string) bool) (<-chan map[string]*UnitStatus, <-chan error) { + old := make(map[string]*UnitStatus) + statusChan := make(chan map[string]*UnitStatus, buffer) + errChan := make(chan error, buffer) + + go func() { + for { + timerChan := time.After(interval) + + units, err := c.ListUnits() + if err == nil { + cur := make(map[string]*UnitStatus) + for i := range units { + if filterUnit != nil && filterUnit(units[i].Name) { + continue + } + cur[units[i].Name] = &units[i] + } + + // add all new or changed units + changed := make(map[string]*UnitStatus) + for n, u := range cur { + if oldU, ok := old[n]; !ok || isChanged(oldU, u) { + changed[n] = u + } + delete(old, n) + } + + // add all deleted units + for oldN := range old { + changed[oldN] = nil + } + + old = cur + + if len(changed) != 0 { + statusChan <- changed + } + } else { + errChan <- err + } + + <-timerChan + } + }() + + return statusChan, errChan +} + +type SubStateUpdate struct { + UnitName string + SubState string +} + +// SetSubStateSubscriber writes to updateCh when any unit's substate changes. +// Although this writes to updateCh on every state change, the reported state +// may be more recent than the change that generated it (due to an unavoidable +// race in the systemd dbus interface). That is, this method provides a good +// way to keep a current view of all units' states, but is not guaranteed to +// show every state transition they go through. Furthermore, state changes +// will only be written to the channel with non-blocking writes. If updateCh +// is full, it attempts to write an error to errCh; if errCh is full, the error +// passes silently. 
+func (c *Conn) SetSubStateSubscriber(updateCh chan<- *SubStateUpdate, errCh chan<- error) { + if c == nil { + msg := "nil receiver" + select { + case errCh <- errors.New(msg): + default: + log.Printf("full error channel while reporting: %s\n", msg) + } + return + } + + c.subStateSubscriber.Lock() + defer c.subStateSubscriber.Unlock() + c.subStateSubscriber.updateCh = updateCh + c.subStateSubscriber.errCh = errCh +} + +func (c *Conn) sendSubStateUpdate(unitPath dbus.ObjectPath) { + c.subStateSubscriber.Lock() + defer c.subStateSubscriber.Unlock() + + if c.subStateSubscriber.updateCh == nil { + return + } + + isIgnored := c.shouldIgnore(unitPath) + defer c.cleanIgnore() + if isIgnored { + return + } + + info, err := c.GetUnitPathProperties(unitPath) + if err != nil { + select { + case c.subStateSubscriber.errCh <- err: + default: + log.Printf("full error channel while reporting: %s\n", err) + } + return + } + defer c.updateIgnore(unitPath, info) + + name, ok := info["Id"].(string) + if !ok { + msg := "failed to cast info.Id" + select { + case c.subStateSubscriber.errCh <- errors.New(msg): + default: + log.Printf("full error channel while reporting: %s\n", err) + } + return + } + substate, ok := info["SubState"].(string) + if !ok { + msg := "failed to cast info.SubState" + select { + case c.subStateSubscriber.errCh <- errors.New(msg): + default: + log.Printf("full error channel while reporting: %s\n", msg) + } + return + } + + update := &SubStateUpdate{name, substate} + select { + case c.subStateSubscriber.updateCh <- update: + default: + msg := "update channel is full" + select { + case c.subStateSubscriber.errCh <- errors.New(msg): + default: + log.Printf("full error channel while reporting: %s\n", msg) + } + return + } +} + +// The ignore functions work around a wart in the systemd dbus interface. +// Requesting the properties of an unloaded unit will cause systemd to send a +// pair of UnitNew/UnitRemoved signals. 
Because we need to get a unit's +// properties on UnitNew (as that's the only indication of a new unit coming up +// for the first time), we would enter an infinite loop if we did not attempt +// to detect and ignore these spurious signals. The signal themselves are +// indistinguishable from relevant ones, so we (somewhat hackishly) ignore an +// unloaded unit's signals for a short time after requesting its properties. +// This means that we will miss e.g. a transient unit being restarted +// *immediately* upon failure and also a transient unit being started +// immediately after requesting its status (with systemctl status, for example, +// because this causes a UnitNew signal to be sent which then causes us to fetch +// the properties). + +func (c *Conn) shouldIgnore(path dbus.ObjectPath) bool { + t, ok := c.subStateSubscriber.ignore[path] + return ok && t >= time.Now().UnixNano() +} + +func (c *Conn) updateIgnore(path dbus.ObjectPath, info map[string]interface{}) { + loadState, ok := info["LoadState"].(string) + if !ok { + return + } + + // unit is unloaded - it will trigger bad systemd dbus behavior + if loadState == "not-found" { + c.subStateSubscriber.ignore[path] = time.Now().UnixNano() + ignoreInterval + } +} + +// without this, ignore would grow unboundedly over time +func (c *Conn) cleanIgnore() { + now := time.Now().UnixNano() + if c.subStateSubscriber.cleanIgnore < now { + c.subStateSubscriber.cleanIgnore = now + cleanIgnoreInterval + + for p, t := range c.subStateSubscriber.ignore { + if t < now { + delete(c.subStateSubscriber.ignore, p) + } + } + } +} + +// PropertiesUpdate holds a map of a unit's changed properties +type PropertiesUpdate struct { + UnitName string + Changed map[string]dbus.Variant +} + +// SetPropertiesSubscriber writes to updateCh when any unit's properties +// change. Every property change reported by systemd will be sent; that is, no +// transitions will be "missed" (as they might be with SetSubStateSubscriber). 
+// However, state changes will only be written to the channel with non-blocking +// writes. If updateCh is full, it attempts to write an error to errCh; if +// errCh is full, the error passes silently. +func (c *Conn) SetPropertiesSubscriber(updateCh chan<- *PropertiesUpdate, errCh chan<- error) { + c.propertiesSubscriber.Lock() + defer c.propertiesSubscriber.Unlock() + c.propertiesSubscriber.updateCh = updateCh + c.propertiesSubscriber.errCh = errCh +} + +// we don't need to worry about shouldIgnore() here because +// sendPropertiesUpdate doesn't call GetProperties() +func (c *Conn) sendPropertiesUpdate(unitPath dbus.ObjectPath, changedProps map[string]dbus.Variant) { + c.propertiesSubscriber.Lock() + defer c.propertiesSubscriber.Unlock() + + if c.propertiesSubscriber.updateCh == nil { + return + } + + update := &PropertiesUpdate{unitName(unitPath), changedProps} + + select { + case c.propertiesSubscriber.updateCh <- update: + default: + msg := "update channel is full" + select { + case c.propertiesSubscriber.errCh <- errors.New(msg): + default: + log.Printf("full error channel while reporting: %s\n", msg) + } + return + } +} diff --git a/vendor/github.com/coreos/go-systemd/v22/dbus/subscription_set.go b/vendor/github.com/coreos/go-systemd/v22/dbus/subscription_set.go new file mode 100644 index 0000000000..5b408d5847 --- /dev/null +++ b/vendor/github.com/coreos/go-systemd/v22/dbus/subscription_set.go @@ -0,0 +1,57 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package dbus + +import ( + "time" +) + +// SubscriptionSet returns a subscription set which is like conn.Subscribe but +// can filter to only return events for a set of units. +type SubscriptionSet struct { + *set + conn *Conn +} + +func (s *SubscriptionSet) filter(unit string) bool { + return !s.Contains(unit) +} + +// Subscribe starts listening for dbus events for all of the units in the set. +// Returns channels identical to conn.SubscribeUnits. +func (s *SubscriptionSet) Subscribe() (<-chan map[string]*UnitStatus, <-chan error) { + // TODO: Make fully evented by using systemd 209 with properties changed values + return s.conn.SubscribeUnitsCustom(time.Second, 0, + mismatchUnitStatus, + func(unit string) bool { return s.filter(unit) }, + ) +} + +// NewSubscriptionSet returns a new subscription set. +func (conn *Conn) NewSubscriptionSet() *SubscriptionSet { + return &SubscriptionSet{newSet(), conn} +} + +// mismatchUnitStatus returns true if the provided UnitStatus objects +// are not equivalent. false is returned if the objects are equivalent. +// Only the Name, Description and state-related fields are used in +// the comparison. 
+func mismatchUnitStatus(u1, u2 *UnitStatus) bool { + return u1.Name != u2.Name || + u1.Description != u2.Description || + u1.LoadState != u2.LoadState || + u1.ActiveState != u2.ActiveState || + u1.SubState != u2.SubState +} diff --git a/vendor/github.com/dblohm7/wingoes/.gitignore b/vendor/github.com/dblohm7/wingoes/.gitignore new file mode 100644 index 0000000000..d27e8563de --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/.gitignore @@ -0,0 +1,19 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Vim +.*.swo +.*.swp + +# Dependency directories (remove the comment below to include it) +# vendor/ diff --git a/vendor/github.com/dblohm7/wingoes/LICENSE b/vendor/github.com/dblohm7/wingoes/LICENSE new file mode 100644 index 0000000000..22e47c7e6b --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, Tailscale Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/dblohm7/wingoes/README.md b/vendor/github.com/dblohm7/wingoes/README.md new file mode 100644 index 0000000000..794e35d943 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/README.md @@ -0,0 +1,2 @@ +# wingoes, an opinionated library for writing Win32 programs in Go + diff --git a/vendor/github.com/dblohm7/wingoes/com/api.go b/vendor/github.com/dblohm7/wingoes/com/api.go new file mode 100644 index 0000000000..9f793b99e2 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/api.go @@ -0,0 +1,140 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package com + +import ( + "unsafe" + + "github.com/dblohm7/wingoes" +) + +// MustGetAppID parses s, a string containing an app ID and returns a pointer to the +// parsed AppID. s must be specified in the format "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}". +// If there is an error parsing s, MustGetAppID panics. 
+func MustGetAppID(s string) *AppID { + return (*AppID)(unsafe.Pointer(wingoes.MustGetGUID(s))) +} + +// MustGetCLSID parses s, a string containing a CLSID and returns a pointer to the +// parsed CLSID. s must be specified in the format "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}". +// If there is an error parsing s, MustGetCLSID panics. +func MustGetCLSID(s string) *CLSID { + return (*CLSID)(unsafe.Pointer(wingoes.MustGetGUID(s))) +} + +// MustGetIID parses s, a string containing an IID and returns a pointer to the +// parsed IID. s must be specified in the format "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}". +// If there is an error parsing s, MustGetIID panics. +func MustGetIID(s string) *IID { + return (*IID)(unsafe.Pointer(wingoes.MustGetGUID(s))) +} + +// MustGetServiceID parses s, a string containing a service ID and returns a pointer to the +// parsed ServiceID. s must be specified in the format "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}". +// If there is an error parsing s, MustGetServiceID panics. +func MustGetServiceID(s string) *ServiceID { + return (*ServiceID)(unsafe.Pointer(wingoes.MustGetGUID(s))) +} + +func getCurrentApartmentInfo() (aptInfo, error) { + var info aptInfo + hr := coGetApartmentType(&info.apt, &info.qualifier) + if err := wingoes.ErrorFromHRESULT(hr); err.Failed() { + return info, err + } + + return info, nil +} + +// aptChecker is a function that applies an arbitrary predicate to an OS thread's +// apartment information, returning true if the input satisifes that predicate. +type aptChecker func(*aptInfo) bool + +// checkCurrentApartment obtains information about the COM apartment that the +// current OS thread resides in, and then passes that information to chk, +// which evaluates that information and determines the return value. 
+func checkCurrentApartment(chk aptChecker) bool { + info, err := getCurrentApartmentInfo() + if err != nil { + return false + } + + return chk(&info) +} + +// AssertCurrentOSThreadSTA checks if the current OS thread resides in a +// single-threaded apartment, and if not, panics. +func AssertCurrentOSThreadSTA() { + if IsCurrentOSThreadSTA() { + return + } + panic("current OS thread does not reside in a single-threaded apartment") +} + +// IsCurrentOSThreadSTA checks if the current OS thread resides in a +// single-threaded apartment and returns true if so. +func IsCurrentOSThreadSTA() bool { + chk := func(i *aptInfo) bool { + return i.apt == coAPTTYPE_STA || i.apt == coAPTTYPE_MAINSTA + } + + return checkCurrentApartment(chk) +} + +// AssertCurrentOSThreadMTA checks if the current OS thread resides in the +// multi-threaded apartment, and if not, panics. +func AssertCurrentOSThreadMTA() { + if IsCurrentOSThreadMTA() { + return + } + panic("current OS thread does not reside in the multi-threaded apartment") +} + +// IsCurrentOSThreadMTA checks if the current OS thread resides in the +// multi-threaded apartment and returns true if so. +func IsCurrentOSThreadMTA() bool { + chk := func(i *aptInfo) bool { + return i.apt == coAPTTYPE_MTA + } + + return checkCurrentApartment(chk) +} + +// createInstanceWithCLSCTX creates a new garbage-collected COM object of type T +// using class clsid. clsctx determines the acceptable location for hosting the +// COM object (in-process, local but out-of-process, or remote). +func createInstanceWithCLSCTX[T Object](clsid *CLSID, clsctx coCLSCTX) (T, error) { + var t T + + iid := t.IID() + ppunk := NewABIReceiver() + + hr := coCreateInstance( + clsid, + nil, + clsctx, + iid, + ppunk, + ) + if err := wingoes.ErrorFromHRESULT(hr); err.Failed() { + return t, err + } + + return t.Make(ppunk).(T), nil +} + +// CreateInstance instantiates a new in-process COM object of type T +// using class clsid. 
+func CreateInstance[T Object](clsid *CLSID) (T, error) { + return createInstanceWithCLSCTX[T](clsid, coCLSCTX_INPROC_SERVER) +} + +// CreateInstance instantiates a new local, out-of-process COM object of type T +// using class clsid. +func CreateOutOfProcessInstance[T Object](clsid *CLSID) (T, error) { + return createInstanceWithCLSCTX[T](clsid, coCLSCTX_LOCAL_SERVER) +} diff --git a/vendor/github.com/dblohm7/wingoes/com/automation/automation.go b/vendor/github.com/dblohm7/wingoes/com/automation/automation.go new file mode 100644 index 0000000000..3689155af3 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/automation/automation.go @@ -0,0 +1,7 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +// Package automation provides essential types for interacting with COM Automation (IDispatch). +package automation diff --git a/vendor/github.com/dblohm7/wingoes/com/automation/mksyscall.go b/vendor/github.com/dblohm7/wingoes/com/automation/mksyscall.go new file mode 100644 index 0000000000..38c3091aff --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/automation/mksyscall.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package automation + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys sysAllocString(str *uint16) (ret BSTR) = oleaut32.SysAllocString +//sys sysAllocStringLen(str *uint16, strLen uint32) (ret BSTR) = oleaut32.SysAllocStringLen +//sys sysFreeString(bstr BSTR) = oleaut32.SysFreeString +//sys sysStringLen(bstr BSTR) (ret uint32) = oleaut32.SysStringLen diff --git a/vendor/github.com/dblohm7/wingoes/com/automation/types.go b/vendor/github.com/dblohm7/wingoes/com/automation/types.go new file mode 100644 index 0000000000..bf8c670197 --- /dev/null +++ 
b/vendor/github.com/dblohm7/wingoes/com/automation/types.go @@ -0,0 +1,86 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package automation + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +// BSTR is the string format used by COM Automation. They are not garbage +// collected and must be explicitly closed when no longer needed. +type BSTR uintptr + +// NewBSTR creates a new BSTR from string s. +func NewBSTR(s string) BSTR { + buf, err := windows.UTF16FromString(s) + if err != nil { + return 0 + } + return NewBSTRFromUTF16(buf) +} + +// NewBSTR creates a new BSTR from slice us, which contains UTF-16 code units. +func NewBSTRFromUTF16(us []uint16) BSTR { + return sysAllocStringLen(&us[0], uint32(len(us))) +} + +// NewBSTR creates a new BSTR from up, a C-style string pointer to UTF-16 code units. +func NewBSTRFromUTF16Ptr(up *uint16) BSTR { + return sysAllocString(up) +} + +// Len returns the length of bs in code units. +func (bs *BSTR) Len() uint32 { + return sysStringLen(*bs) +} + +// String returns the contents of bs as a Go string. +func (bs *BSTR) String() string { + return windows.UTF16ToString(bs.toUTF16()) +} + +// toUTF16 is unsafe for general use because it returns a pointer that is +// not managed by the Go GC. +func (bs *BSTR) toUTF16() []uint16 { + return unsafe.Slice(bs.toUTF16Ptr(), bs.Len()) +} + +// ToUTF16 returns the contents of bs as a slice of UTF-16 code units. +func (bs *BSTR) ToUTF16() []uint16 { + return append([]uint16{}, bs.toUTF16()...) +} + +// toUTF16Ptr is unsafe for general use because it returns a pointer that is +// not managed by the Go GC. +func (bs *BSTR) toUTF16Ptr() *uint16 { + return (*uint16)(unsafe.Pointer(*bs)) +} + +// ToUTF16 returns the contents of bs as C-style string pointer to UTF-16 code units. 
+func (bs *BSTR) ToUTF16Ptr() *uint16 { + slc := bs.ToUTF16() + return &slc[0] +} + +// Clone creates a clone of bs whose lifetime becomes independent of the original. +// It must be explicitly closed when no longer needed. +func (bs *BSTR) Clone() BSTR { + return sysAllocStringLen(bs.toUTF16Ptr(), bs.Len()) +} + +// IsNil returns true if bs holds a nil value. +func (bs *BSTR) IsNil() bool { + return *bs == 0 +} + +// Close frees bs. +func (bs *BSTR) Close() error { + sysFreeString(*bs) + *bs = 0 + return nil +} diff --git a/vendor/github.com/dblohm7/wingoes/com/automation/zsyscall_windows.go b/vendor/github.com/dblohm7/wingoes/com/automation/zsyscall_windows.go new file mode 100644 index 0000000000..266f58d1b4 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/automation/zsyscall_windows.go @@ -0,0 +1,70 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package automation + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) 
+ return e +} + +var ( + modoleaut32 = windows.NewLazySystemDLL("oleaut32.dll") + + procSysAllocString = modoleaut32.NewProc("SysAllocString") + procSysAllocStringLen = modoleaut32.NewProc("SysAllocStringLen") + procSysFreeString = modoleaut32.NewProc("SysFreeString") + procSysStringLen = modoleaut32.NewProc("SysStringLen") +) + +func sysAllocString(str *uint16) (ret BSTR) { + r0, _, _ := syscall.Syscall(procSysAllocString.Addr(), 1, uintptr(unsafe.Pointer(str)), 0, 0) + ret = BSTR(r0) + return +} + +func sysAllocStringLen(str *uint16, strLen uint32) (ret BSTR) { + r0, _, _ := syscall.Syscall(procSysAllocStringLen.Addr(), 2, uintptr(unsafe.Pointer(str)), uintptr(strLen), 0) + ret = BSTR(r0) + return +} + +func sysFreeString(bstr BSTR) { + syscall.Syscall(procSysFreeString.Addr(), 1, uintptr(bstr), 0, 0) + return +} + +func sysStringLen(bstr BSTR) (ret uint32) { + r0, _, _ := syscall.Syscall(procSysStringLen.Addr(), 1, uintptr(bstr), 0, 0) + ret = uint32(r0) + return +} diff --git a/vendor/github.com/dblohm7/wingoes/com/com.go b/vendor/github.com/dblohm7/wingoes/com/com.go new file mode 100644 index 0000000000..33f1071d7c --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/com.go @@ -0,0 +1,9 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +// Package com provides an idiomatic foundation for instantiating and invoking +// COM objects. +package com diff --git a/vendor/github.com/dblohm7/wingoes/com/globalopts.go b/vendor/github.com/dblohm7/wingoes/com/globalopts.go new file mode 100644 index 0000000000..be48fdb48a --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/globalopts.go @@ -0,0 +1,120 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build windows + +package com + +import ( + "runtime" + "syscall" + "unsafe" + + "github.com/dblohm7/wingoes" +) + +var ( + CLSID_GlobalOptions = &CLSID{0x0000034B, 0x0000, 0x0000, [8]byte{0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}} +) + +var ( + IID_IGlobalOptions = &IID{0x0000015B, 0x0000, 0x0000, [8]byte{0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}} +) + +type GLOBALOPT_PROPERTIES int32 + +const ( + COMGLB_EXCEPTION_HANDLING = GLOBALOPT_PROPERTIES(1) + COMGLB_APPID = GLOBALOPT_PROPERTIES(2) + COMGLB_RPC_THREADPOOL_SETTING = GLOBALOPT_PROPERTIES(3) + COMGLB_RO_SETTINGS = GLOBALOPT_PROPERTIES(4) + COMGLB_UNMARSHALING_POLICY = GLOBALOPT_PROPERTIES(5) +) + +const ( + COMGLB_EXCEPTION_HANDLE = 0 + COMGLB_EXCEPTION_DONOT_HANDLE_FATAL = 1 + COMGLB_EXCEPTION_DONOT_HANDLE = 1 + COMGLB_EXCEPTION_DONOT_HANDLE_ANY = 2 +) + +// IGlobalOptionsABI represents the COM ABI for the IGlobalOptions interface. +type IGlobalOptionsABI struct { + IUnknownABI +} + +// GlobalOptions is the COM object used for setting global configuration settings +// on the COM runtime. It must be called after COM runtime security has been +// initialized, but before anything else "significant" is done using COM. 
+type GlobalOptions struct { + GenericObject[IGlobalOptionsABI] +} + +func (abi *IGlobalOptionsABI) Set(prop GLOBALOPT_PROPERTIES, value uintptr) error { + method := unsafe.Slice(abi.Vtbl, 5)[3] + + rc, _, _ := syscall.Syscall( + method, + 3, + uintptr(unsafe.Pointer(abi)), + uintptr(prop), + value, + ) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() { + return e + } + + return nil +} + +func (abi *IGlobalOptionsABI) Query(prop GLOBALOPT_PROPERTIES) (uintptr, error) { + var result uintptr + method := unsafe.Slice(abi.Vtbl, 5)[4] + + rc, _, _ := syscall.Syscall( + method, + 3, + uintptr(unsafe.Pointer(abi)), + uintptr(prop), + uintptr(unsafe.Pointer(&result)), + ) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() { + return 0, e + } + + return result, nil +} + +func (o GlobalOptions) IID() *IID { + return IID_IGlobalOptions +} + +func (o GlobalOptions) Make(r ABIReceiver) any { + if r == nil { + return GlobalOptions{} + } + + runtime.SetFinalizer(r, ReleaseABI) + + pp := (**IGlobalOptionsABI)(unsafe.Pointer(r)) + return GlobalOptions{GenericObject[IGlobalOptionsABI]{Pp: pp}} +} + +// UnsafeUnwrap returns the underlying IGlobalOptionsABI of the object. As the +// name implies, this is unsafe -- you had better know what you are doing! +func (o GlobalOptions) UnsafeUnwrap() *IGlobalOptionsABI { + return *(o.Pp) +} + +// Set sets the global property prop to value. +func (o GlobalOptions) Set(prop GLOBALOPT_PROPERTIES, value uintptr) error { + p := *(o.Pp) + return p.Set(prop, value) +} + +// Query returns the value of global property prop. 
+func (o GlobalOptions) Query(prop GLOBALOPT_PROPERTIES) (uintptr, error) { + p := *(o.Pp) + return p.Query(prop) +} diff --git a/vendor/github.com/dblohm7/wingoes/com/interface.go b/vendor/github.com/dblohm7/wingoes/com/interface.go new file mode 100644 index 0000000000..522e5c8bcf --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/interface.go @@ -0,0 +1,112 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package com + +import ( + "syscall" + "unsafe" + + "github.com/dblohm7/wingoes" +) + +// IUnknown is the base COM interface. +type IUnknown interface { + QueryInterface(iid *IID) (IUnknown, error) + AddRef() int32 + Release() int32 +} + +// This is a sentinel that indicates that a struct implements the COM ABI. +// Only IUnknownABI should implement this. +type hasVTable interface { + vtable() *uintptr +} + +// IUnknownABI describes the ABI of the IUnknown interface (ie, a vtable). +type IUnknownABI struct { + Vtbl *uintptr +} + +func (abi IUnknownABI) vtable() *uintptr { + return abi.Vtbl +} + +// ABI is a type constraint allowing the COM ABI, or any struct that embeds it. +type ABI interface { + hasVTable +} + +// PUnknown is a type constraint for types that both implement IUnknown and +// are also pointers to a COM ABI. +type PUnknown[A ABI] interface { + IUnknown + *A +} + +// ABIReceiver is the type that receives COM interface pointers from COM +// method outparams. +type ABIReceiver **IUnknownABI + +// NewABIReceiver instantiates a new ABIReceiver. +func NewABIReceiver() ABIReceiver { + return ABIReceiver(new(*IUnknownABI)) +} + +// ReleaseABI releases a COM object. Finalizers must always invoke this function +// when destroying COM interfaces. +func ReleaseABI(p **IUnknownABI) { + (*p).Release() +} + +// QueryInterface implements the QueryInterface call for a COM interface pointer. 
+// iid is the desired interface ID. +func (abi *IUnknownABI) QueryInterface(iid *IID) (IUnknown, error) { + var punk *IUnknownABI + + r, _, _ := syscall.Syscall( + *(abi.Vtbl), + 3, + uintptr(unsafe.Pointer(abi)), + uintptr(unsafe.Pointer(iid)), + uintptr(unsafe.Pointer(&punk)), + ) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(r)); e.Failed() { + return nil, e + } + + return punk, nil +} + +// AddRef implements the AddRef call for a COM interface pointer. +func (abi *IUnknownABI) AddRef() int32 { + method := unsafe.Slice(abi.Vtbl, 3)[1] + + r, _, _ := syscall.Syscall( + method, + 1, + uintptr(unsafe.Pointer(abi)), + 0, + 0, + ) + + return int32(r) +} + +// Release implements the Release call for a COM interface pointer. +func (abi *IUnknownABI) Release() int32 { + method := unsafe.Slice(abi.Vtbl, 3)[2] + + r, _, _ := syscall.Syscall( + method, + 1, + uintptr(unsafe.Pointer(abi)), + 0, + 0, + ) + + return int32(r) +} diff --git a/vendor/github.com/dblohm7/wingoes/com/mksyscall.go b/vendor/github.com/dblohm7/wingoes/com/mksyscall.go new file mode 100644 index 0000000000..f26be9687a --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/mksyscall.go @@ -0,0 +1,25 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build windows + +package com + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys coCreateInstance(clsid *CLSID, unkOuter *IUnknownABI, clsctx coCLSCTX, iid *IID, ppv **IUnknownABI) (hr wingoes.HRESULT) = ole32.CoCreateInstance +//sys coGetApartmentType(aptType *coAPTTYPE, qual *coAPTTYPEQUALIFIER) (hr wingoes.HRESULT) = ole32.CoGetApartmentType +//sys coInitializeEx(reserved uintptr, flags uint32) (hr wingoes.HRESULT) = ole32.CoInitializeEx +//sys coInitializeSecurity(sd *windows.SECURITY_DESCRIPTOR, authSvcLen int32, authSvc *soleAuthenticationService, reserved1 uintptr, authnLevel rpcAuthnLevel, impLevel rpcImpersonationLevel, authList *soleAuthenticationList, capabilities authCapabilities, reserved2 uintptr) (hr wingoes.HRESULT) = ole32.CoInitializeSecurity + +// We don't use '?' on coIncrementMTAUsage because that doesn't play nicely with HRESULTs. We manually check for its presence in process.go +//sys coIncrementMTAUsage(cookie *coMTAUsageCookie) (hr wingoes.HRESULT) = ole32.CoIncrementMTAUsage + +// Technically this proc is __cdecl, but since it has 0 args this doesn't matter +//sys setOaNoCache() = oleaut32.SetOaNoCache + +// For the following two functions we use IUnknownABI instead of IStreamABI because it makes the callsites cleaner. +//sys shCreateMemStream(pInit *byte, cbInit uint32) (stream *IUnknownABI) = shlwapi.SHCreateMemStream +//sys createStreamOnHGlobal(hglobal internal.HGLOBAL, deleteOnRelease bool, stream **IUnknownABI) (hr wingoes.HRESULT) = ole32.CreateStreamOnHGlobal diff --git a/vendor/github.com/dblohm7/wingoes/com/object.go b/vendor/github.com/dblohm7/wingoes/com/object.go new file mode 100644 index 0000000000..3f18d31512 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/object.go @@ -0,0 +1,89 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package com + +import ( + "fmt" + "unsafe" +) + +// GenericObject is a struct that wraps any interface that implements the COM ABI. +type GenericObject[A ABI] struct { + Pp **A +} + +func (o GenericObject[A]) pp() **A { + return o.Pp +} + +// Object is the interface that all garbage-collected instances of COM interfaces +// must implement. +type Object interface { + // IID returns the interface ID for the object. This method may be called + // on Objects containing the zero value, so its return value must not depend + // on the value of the method's receiver. + IID() *IID + + // Make converts r to an instance of a garbage-collected COM object. The type + // of its return value must always match the type of the method's receiver. + Make(r ABIReceiver) any +} + +// EmbedsGenericObject is a type constraint matching any struct that embeds +// a GenericObject[A]. +type EmbedsGenericObject[A ABI] interface { + Object + ~struct{ GenericObject[A] } + pp() **A +} + +// As casts obj to an object of type O, or panics if obj cannot be converted to O. +func As[O Object, A ABI, PU PUnknown[A], E EmbedsGenericObject[A]](obj E) O { + o, err := TryAs[O, A, PU](obj) + if err != nil { + panic(fmt.Sprintf("wingoes.com.As error: %v", err)) + } + return o +} + +// TryAs casts obj to an object of type O, or returns an error if obj cannot be +// converted to O. +func TryAs[O Object, A ABI, PU PUnknown[A], E EmbedsGenericObject[A]](obj E) (O, error) { + var o O + + iid := o.IID() + p := (PU)(unsafe.Pointer(*(obj.pp()))) + + i, err := p.QueryInterface(iid) + if err != nil { + return o, err + } + + r := NewABIReceiver() + *r = i.(*IUnknownABI) + + return o.Make(r).(O), nil +} + +// IsSameObject returns true when both l and r refer to the same underlying object. 
+func IsSameObject[AL, AR ABI, PL PUnknown[AL], PR PUnknown[AR], EL EmbedsGenericObject[AL], ER EmbedsGenericObject[AR]](l EL, r ER) bool { + pl := (PL)(unsafe.Pointer(*(l.pp()))) + ul, err := pl.QueryInterface(IID_IUnknown) + if err != nil { + return false + } + defer ul.Release() + + pr := (PR)(unsafe.Pointer(*(r.pp()))) + ur, err := pr.QueryInterface(IID_IUnknown) + if err != nil { + return false + } + defer ur.Release() + + return ul.(*IUnknownABI) == ur.(*IUnknownABI) +} diff --git a/vendor/github.com/dblohm7/wingoes/com/process.go b/vendor/github.com/dblohm7/wingoes/com/process.go new file mode 100644 index 0000000000..8426e462da --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/process.go @@ -0,0 +1,268 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package com + +import ( + "os" + "runtime" + + "github.com/dblohm7/wingoes" + "golang.org/x/sys/windows" +) + +// ProcessType is an enumeration that specifies the type of the current process +// when calling StartRuntime. +type ProcessType uint + +const ( + // ConsoleApp is a text-mode Windows program. + ConsoleApp = ProcessType(iota) + // Service is a Windows service. + Service + // GUIApp is a GUI-mode Windows program. + GUIApp + + // Note: Even though this implementation is not yet internally distinguishing + // between console apps and services, this distinction may be useful in the + // future. For example, a service could receive more restrictive default + // security settings than a console app. + // Having this as part of the API now avoids future breakage. +) + +// StartRuntime permanently initializes COM for the remaining lifetime of the +// current process. To avoid errors, it should be called as early as possible +// during program initialization. 
When processType == GUIApp, the current +// OS thread becomes permanently locked to the current goroutine; any subsequent +// GUI *must* be created on the same OS thread. +// An excellent location to call StartRuntime is in the init function of the +// main package. +func StartRuntime(processType ProcessType) error { + return StartRuntimeWithDACL(processType, nil) +} + +// StartRuntimeWithDACL permanently initializes COM for the remaining lifetime +// of the current process. To avoid errors, it should be called as early as +// possible during program initialization. When processType == GUIApp, the +// current OS thread becomes permanently locked to the current goroutine; any +// subsequent GUI *must* be created on the same OS thread. dacl is an ACL that +// controls access of other processes connecting to the current process over COM. +// For further information about COM access control, look up the COM_RIGHTS_* +// access flags in the Windows developer documentation. +// An excellent location to call StartRuntimeWithDACL is in the init function of +// the main package. +func StartRuntimeWithDACL(processType ProcessType, dacl *windows.ACL) error { + runtime.LockOSThread() + + defer func() { + // When initializing for non-GUI processes, the OS thread may be unlocked + // upon return from this function. + if processType != GUIApp { + runtime.UnlockOSThread() + } + }() + + switch processType { + case ConsoleApp, Service: + // Just start the MTA implicitly. + if err := startMTAImplicitly(); err != nil { + return err + } + case GUIApp: + // For GUIApp, we want the current OS thread to enter a single-threaded + // apartment (STA). However, we want all other OS threads to reside inside + // a multi-threaded apartment (MTA). The way to so this is to first start + // the MTA implicitly, affecting all OS threads who have not yet explicitly + // entered a COM apartment... 
+ if err := startMTAImplicitly(); err != nil { + runtime.UnlockOSThread() + return err + } + // ...and then subsequently explicitly enter a STA on this OS thread, which + // automatically removes this OS thread from the MTA. + if err := enterSTA(); err != nil { + runtime.UnlockOSThread() + return err + } + // From this point forward, we must never unlock the OS thread. + default: + return os.ErrInvalid + } + + // Order is extremely important here: initSecurity must be called immediately + // after apartments are set up, but before doing anything else. + if err := initSecurity(dacl); err != nil { + return err + } + + // By default, for compatibility reasons, COM internally sets a catch-all + // exception handler at its API boundary. This is dangerous, so we override it. + // This work must happen after security settings are initialized, but before + // anything "significant" is done with COM. + globalOpts, err := CreateInstance[GlobalOptions](CLSID_GlobalOptions) + if err != nil { + return err + } + + err = globalOpts.Set(COMGLB_EXCEPTION_HANDLING, COMGLB_EXCEPTION_DONOT_HANDLE_ANY) + + // The BSTR cache never invalidates itself, so we disable it unconditionally. + // We do this here to ensure that the BSTR cache is off before anything + // can possibly start using oleaut32.dll. + setOaNoCache() + + return err +} + +// startMTAImplicitly creates an implicit multi-threaded apartment (MTA) for +// all threads in a process that do not otherwise explicitly enter a COM apartment. +func startMTAImplicitly() error { + // CoIncrementMTAUsage is the modern API to use for creating the MTA implicitly, + // however we may fall back to a legacy mechanism when the former API is unavailable. + if err := procCoIncrementMTAUsage.Find(); err != nil { + return startMTAImplicitlyLegacy() + } + + // We do not retain cookie beyond this function, as we have no intention of + // tearing any of this back down. 
+ var cookie coMTAUsageCookie + hr := coIncrementMTAUsage(&cookie) + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + + return nil +} + +// startMTAImplicitlyLegacy works by having a background OS thread explicitly enter +// the multi-threaded apartment. All other OS threads that have not explicitly +// entered an apartment will become implicit members of that MTA. This function is +// written assuming that the current OS thread has already been locked. +func startMTAImplicitlyLegacy() error { + // We need to start the MTA on a background OS thread, HOWEVER we also want this + // to happen synchronously, so we wait on c for MTA initialization to complete. + c := make(chan error) + go bgMTASustainer(c) + return <-c +} + +// bgMTASustainer locks the current goroutine to the current OS thread, enters +// the COM multi-threaded apartment, and then blocks for the remainder of the +// process's lifetime. It sends its result to c so that startMTAImplicitlyLegacy +// can wait for the MTA to be ready before proceeding. +func bgMTASustainer(c chan error) { + runtime.LockOSThread() + err := enterMTA() + c <- err + if err != nil { + // We didn't enter the MTA, so just unlock and bail. + runtime.UnlockOSThread() + return + } + select {} +} + +// enterMTA causes the current OS thread to explicitly declare itself to be a +// member of COM's multi-threaded apartment. Note that this function affects +// thread-local state, so use carefully! +func enterMTA() error { + return coInit(windows.COINIT_MULTITHREADED) +} + +// enterSTA causes the current OS thread to create and enter a single-threaded +// apartment. The current OS thread must be locked and remain locked for the +// duration of the thread's time in the apartment. For our purposes, the calling +// OS thread never leaves the STA, so it must effectively remain locked for +// the remaining lifetime of the process. 
A single-threaded apartment should be +// used if and only if an OS thread is going to be creating windows and pumping +// messages; STAs are NOT generic containers for single-threaded COM code, +// contrary to popular belief. Note that this function affects thread-local +// state, so use carefully! +func enterSTA() error { + return coInit(windows.COINIT_APARTMENTTHREADED) +} + +// coInit is a wrapper for CoInitializeEx that properly handles the S_FALSE +// error code (x/sys/windows.CoInitializeEx does not). +func coInit(apartment uint32) error { + hr := coInitializeEx(0, apartment) + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + + return nil +} + +const ( + authSvcCOMChooses = -1 +) + +// initSecurity initializes COM security using the ACL specified by dacl. +// A nil dacl implies that a default ACL should be used instead. +func initSecurity(dacl *windows.ACL) error { + sd, err := buildSecurityDescriptor(dacl) + if err != nil { + return err + } + + caps := authCapNone + if sd == nil { + // For COM to fall back to system-wide defaults, we need to set this bit. + caps |= authCapAppID + } + + hr := coInitializeSecurity( + sd, + authSvcCOMChooses, + nil, // authSvc (not used because previous arg is authSvcCOMChooses) + 0, // Reserved, must be 0 + rpcAuthnLevelDefault, + rpcImpLevelIdentify, + nil, // authlist: use defaults + caps, + 0, // Reserved, must be 0 + ) + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + + return nil +} + +// buildSecurityDescriptor inserts dacl into a valid security descriptor for use +// with CoInitializeSecurity. A nil dacl results in a nil security descriptor, +// which we consider to be a valid "use defaults" sentinel. +func buildSecurityDescriptor(dacl *windows.ACL) (*windows.SECURITY_DESCRIPTOR, error) { + if dacl == nil { + // Not an error, just use defaults. 
+ return nil, nil + } + + sd, err := windows.NewSecurityDescriptor() + if err != nil { + return nil, err + } + + if err := sd.SetDACL(dacl, true, false); err != nil { + return nil, err + } + + // CoInitializeSecurity will fail unless the SD's owner and group are both set. + userSIDs, err := wingoes.CurrentProcessUserSIDs() + if err != nil { + return nil, err + } + + if err := sd.SetOwner(userSIDs.User, false); err != nil { + return nil, err + } + + if err := sd.SetGroup(userSIDs.PrimaryGroup, false); err != nil { + return nil, err + } + + return sd, nil +} diff --git a/vendor/github.com/dblohm7/wingoes/com/stream.go b/vendor/github.com/dblohm7/wingoes/com/stream.go new file mode 100644 index 0000000000..5814e1719a --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/stream.go @@ -0,0 +1,556 @@ +// Copyright (c) 2023 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package com + +import ( + "io" + "runtime" + "syscall" + "unsafe" + + "github.com/dblohm7/wingoes" + "github.com/dblohm7/wingoes/internal" + "golang.org/x/sys/windows" +) + +var ( + IID_ISequentialStream = &IID{0x0C733A30, 0x2A1C, 0x11CE, [8]byte{0xAD, 0xE5, 0x00, 0xAA, 0x00, 0x44, 0x77, 0x3D}} + IID_IStream = &IID{0x0000000C, 0x0000, 0x0000, [8]byte{0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}} +) + +type STGC uint32 + +const ( + STGC_DEFAULT = STGC(0) + STGC_OVERWRITE = STGC(1) + STGC_ONLYIFCURRENT = STGC(2) + STGC_DANGEROUSLYCOMMITMERELYTODISKCACHE = STGC(4) + STGC_CONSOLIDATE = STGC(8) +) + +type LOCKTYPE uint32 + +const ( + LOCK_WRITE = LOCKTYPE(1) + LOCK_EXCLUSIVE = LOCKTYPE(2) + LOCK_ONLYONCE = LOCKTYPE(4) +) + +type STGTY uint32 + +const ( + STGTY_STORAGE = STGTY(1) + STGTY_STREAM = STGTY(2) + STGTY_LOCKBYTES = STGTY(3) + STGTY_PROPERTY = STGTY(4) +) + +type STATFLAG uint32 + +const ( + STATFLAG_DEFAULT = STATFLAG(0) + STATFLAG_NONAME = STATFLAG(1) + 
STATFLAG_NOOPEN = STATFLAG(2) +) + +type STATSTG struct { + Name COMAllocatedString + Type STGTY + Size uint64 + MTime windows.Filetime + CTime windows.Filetime + ATime windows.Filetime + Mode uint32 + LocksSupported LOCKTYPE + ClsID CLSID + _ uint32 // StateBits + _ uint32 // reserved +} + +func (st *STATSTG) Close() error { + return st.Name.Close() +} + +type ISequentialStreamABI struct { + IUnknownABI +} + +type IStreamABI struct { + ISequentialStreamABI +} + +type SequentialStream struct { + GenericObject[ISequentialStreamABI] +} + +type Stream struct { + GenericObject[IStreamABI] +} + +func (abi *ISequentialStreamABI) Read(p []byte) (int, error) { + if len(p) > maxStreamRWLen { + p = p[:maxStreamRWLen] + } + + var cbRead uint32 + method := unsafe.Slice(abi.Vtbl, 5)[3] + + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(unsafe.Pointer(&p[0])), + uintptr(uint32(len(p))), + uintptr(unsafe.Pointer(&cbRead)), + ) + n := int(cbRead) + e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)) + if e.Failed() { + return n, e + } + + // Various implementations of IStream handle EOF differently. We need to + // deal with both. + if e.AsHRESULT() == wingoes.S_FALSE || (n == 0 && len(p) > 0) { + return n, io.EOF + } + + return n, nil +} + +func (abi *ISequentialStreamABI) Write(p []byte) (int, error) { + w := p + if len(w) > maxStreamRWLen { + w = w[:maxStreamRWLen] + } + + var cbWritten uint32 + method := unsafe.Slice(abi.Vtbl, 5)[4] + + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(unsafe.Pointer(&w[0])), + uintptr(uint32(len(w))), + uintptr(unsafe.Pointer(&cbWritten)), + ) + n := int(cbWritten) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() { + return n, e + } + + // Need this to satisfy Writer. 
+ if n < len(p) { + return n, io.ErrShortWrite + } + + return n, nil +} + +func (o SequentialStream) IID() *IID { + return IID_ISequentialStream +} + +func (o SequentialStream) Make(r ABIReceiver) any { + if r == nil { + return SequentialStream{} + } + + runtime.SetFinalizer(r, ReleaseABI) + + pp := (**ISequentialStreamABI)(unsafe.Pointer(r)) + return SequentialStream{GenericObject[ISequentialStreamABI]{Pp: pp}} +} + +func (o SequentialStream) UnsafeUnwrap() *ISequentialStreamABI { + return *(o.Pp) +} + +func (o SequentialStream) Read(b []byte) (n int, err error) { + p := *(o.Pp) + return p.Read(b) +} + +func (o SequentialStream) Write(b []byte) (int, error) { + p := *(o.Pp) + return p.Write(b) +} + +func (abi *IStreamABI) Seek(offset int64, whence int) (n int64, _ error) { + var hr wingoes.HRESULT + method := unsafe.Slice(abi.Vtbl, 14)[5] + + if runtime.GOARCH == "386" { + words := (*[2]uintptr)(unsafe.Pointer(&offset)) + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + words[0], + words[1], + uintptr(uint32(whence)), + uintptr(unsafe.Pointer(&n)), + ) + hr = wingoes.HRESULT(rc) + } else { + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(offset), + uintptr(uint32(whence)), + uintptr(unsafe.Pointer(&n)), + ) + hr = wingoes.HRESULT(rc) + } + + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return 0, e + } + + return n, nil +} + +func (abi *IStreamABI) SetSize(newSize uint64) error { + var hr wingoes.HRESULT + method := unsafe.Slice(abi.Vtbl, 14)[6] + + if runtime.GOARCH == "386" { + words := (*[2]uintptr)(unsafe.Pointer(&newSize)) + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + words[0], + words[1], + ) + hr = wingoes.HRESULT(rc) + } else { + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(newSize), + ) + hr = wingoes.HRESULT(rc) + } + + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + + return nil +} + +func (abi 
*IStreamABI) CopyTo(dest *IStreamABI, numBytesToCopy uint64) (bytesRead, bytesWritten uint64, _ error) { + var hr wingoes.HRESULT + method := unsafe.Slice(abi.Vtbl, 14)[7] + + if runtime.GOARCH == "386" { + words := (*[2]uintptr)(unsafe.Pointer(&numBytesToCopy)) + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(unsafe.Pointer(dest)), + words[0], + words[1], + uintptr(unsafe.Pointer(&bytesRead)), + uintptr(unsafe.Pointer(&bytesWritten)), + ) + hr = wingoes.HRESULT(rc) + } else { + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(unsafe.Pointer(dest)), + uintptr(numBytesToCopy), + uintptr(unsafe.Pointer(&bytesRead)), + uintptr(unsafe.Pointer(&bytesWritten)), + ) + hr = wingoes.HRESULT(rc) + } + + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return bytesRead, bytesWritten, e + } + + return bytesRead, bytesWritten, nil +} + +func (abi *IStreamABI) Commit(flags STGC) error { + method := unsafe.Slice(abi.Vtbl, 14)[8] + + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(flags), + ) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() { + return e + } + + return nil +} + +func (abi *IStreamABI) Revert() error { + method := unsafe.Slice(abi.Vtbl, 14)[9] + + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + ) + + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() { + return e + } + + return nil +} + +func (abi *IStreamABI) LockRegion(offset, numBytes uint64, lockType LOCKTYPE) error { + var hr wingoes.HRESULT + method := unsafe.Slice(abi.Vtbl, 14)[10] + + if runtime.GOARCH == "386" { + oWords := (*[2]uintptr)(unsafe.Pointer(&offset)) + nWords := (*[2]uintptr)(unsafe.Pointer(&numBytes)) + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + oWords[0], + oWords[1], + nWords[0], + nWords[1], + uintptr(lockType), + ) + hr = wingoes.HRESULT(rc) + } else { + rc, _, _ := syscall.SyscallN( + method, + 
uintptr(unsafe.Pointer(abi)), + uintptr(offset), + uintptr(numBytes), + uintptr(lockType), + ) + hr = wingoes.HRESULT(rc) + } + + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + + return nil +} + +func (abi *IStreamABI) UnlockRegion(offset, numBytes uint64, lockType LOCKTYPE) error { + var hr wingoes.HRESULT + method := unsafe.Slice(abi.Vtbl, 14)[11] + + if runtime.GOARCH == "386" { + oWords := (*[2]uintptr)(unsafe.Pointer(&offset)) + nWords := (*[2]uintptr)(unsafe.Pointer(&numBytes)) + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + oWords[0], + oWords[1], + nWords[0], + nWords[1], + uintptr(lockType), + ) + hr = wingoes.HRESULT(rc) + } else { + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(offset), + uintptr(numBytes), + uintptr(lockType), + ) + hr = wingoes.HRESULT(rc) + } + + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + + return nil +} + +func (abi *IStreamABI) Stat(flags STATFLAG) (*STATSTG, error) { + result := new(STATSTG) + method := unsafe.Slice(abi.Vtbl, 14)[12] + + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(unsafe.Pointer(result)), + uintptr(flags), + ) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() { + return nil, e + } + + return result, nil +} + +func (abi *IStreamABI) Clone() (result *IUnknownABI, _ error) { + method := unsafe.Slice(abi.Vtbl, 14)[13] + + rc, _, _ := syscall.SyscallN( + method, + uintptr(unsafe.Pointer(abi)), + uintptr(unsafe.Pointer(&result)), + ) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() { + return nil, e + } + + return result, nil +} + +func (o Stream) IID() *IID { + return IID_IStream +} + +func (o Stream) Make(r ABIReceiver) any { + if r == nil { + return Stream{} + } + + runtime.SetFinalizer(r, ReleaseABI) + + pp := (**IStreamABI)(unsafe.Pointer(r)) + return Stream{GenericObject[IStreamABI]{Pp: pp}} +} + +func (o Stream) UnsafeUnwrap() 
*IStreamABI { + return *(o.Pp) +} + +func (o Stream) Read(buf []byte) (int, error) { + p := *(o.Pp) + return p.Read(buf) +} + +func (o Stream) Write(buf []byte) (int, error) { + p := *(o.Pp) + return p.Write(buf) +} + +func (o Stream) Seek(offset int64, whence int) (n int64, _ error) { + p := *(o.Pp) + return p.Seek(offset, whence) +} + +func (o Stream) SetSize(newSize uint64) error { + p := *(o.Pp) + return p.SetSize(newSize) +} + +func (o Stream) CopyTo(dest Stream, numBytesToCopy uint64) (bytesRead, bytesWritten uint64, _ error) { + p := *(o.Pp) + return p.CopyTo(dest.UnsafeUnwrap(), numBytesToCopy) +} + +func (o Stream) Commit(flags STGC) error { + p := *(o.Pp) + return p.Commit(flags) +} + +func (o Stream) Revert() error { + p := *(o.Pp) + return p.Revert() +} + +func (o Stream) LockRegion(offset, numBytes uint64, lockType LOCKTYPE) error { + p := *(o.Pp) + return p.LockRegion(offset, numBytes, lockType) +} + +func (o Stream) UnlockRegion(offset, numBytes uint64, lockType LOCKTYPE) error { + p := *(o.Pp) + return p.UnlockRegion(offset, numBytes, lockType) +} + +func (o Stream) Stat(flags STATFLAG) (*STATSTG, error) { + p := *(o.Pp) + return p.Stat(flags) +} + +func (o Stream) Clone() (result Stream, _ error) { + p := *(o.Pp) + punk, err := p.Clone() + if err != nil { + return result, err + } + + return result.Make(&punk).(Stream), nil +} + +const hrE_OUTOFMEMORY = wingoes.HRESULT(-((0x8007000E ^ 0xFFFFFFFF) + 1)) + +// NewMemoryStream creates a new in-memory Stream object initially containing a +// copy of initialBytes. Its seek pointer is guaranteed to reference the +// beginning of the stream. 
+func NewMemoryStream(initialBytes []byte) (result Stream, _ error) { + return newMemoryStreamInternal(initialBytes, false) +} + +func newMemoryStreamInternal(initialBytes []byte, forceLegacy bool) (result Stream, _ error) { + if len(initialBytes) > maxStreamRWLen { + return result, wingoes.ErrorFromHRESULT(hrE_OUTOFMEMORY) + } + + // SHCreateMemStream exists on Win7 but is not safe for us to use until Win8. + if forceLegacy || !wingoes.IsWin8OrGreater() { + return newMemoryStreamLegacy(initialBytes) + } + + var base *byte + var length uint32 + if l := uint32(len(initialBytes)); l > 0 { + base = &initialBytes[0] + length = l + } + + punk := shCreateMemStream(base, length) + if punk == nil { + return result, wingoes.ErrorFromHRESULT(hrE_OUTOFMEMORY) + } + + obj := result.Make(&punk).(Stream) + if _, err := obj.Seek(0, io.SeekStart); err != nil { + return result, err + } + + return obj, nil +} + +func newMemoryStreamLegacy(initialBytes []byte) (result Stream, _ error) { + ppstream := NewABIReceiver() + hr := createStreamOnHGlobal(internal.HGLOBAL(0), true, ppstream) + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return result, e + } + + obj := result.Make(ppstream).(Stream) + + if err := obj.SetSize(uint64(len(initialBytes))); err != nil { + return result, err + } + + if len(initialBytes) == 0 { + return obj, nil + } + + _, err := obj.Write(initialBytes) + if err != nil { + return result, err + } + + if _, err := obj.Seek(0, io.SeekStart); err != nil { + return result, err + } + + return obj, nil +} diff --git a/vendor/github.com/dblohm7/wingoes/com/stream_not386.go b/vendor/github.com/dblohm7/wingoes/com/stream_not386.go new file mode 100644 index 0000000000..7b82a40226 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/stream_not386.go @@ -0,0 +1,13 @@ +// Copyright (c) 2023 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build windows && !386 + +package com + +import ( + "math" +) + +const maxStreamRWLen = math.MaxUint32 diff --git a/vendor/github.com/dblohm7/wingoes/com/stream_windows_386.go b/vendor/github.com/dblohm7/wingoes/com/stream_windows_386.go new file mode 100644 index 0000000000..c68e24ec34 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/stream_windows_386.go @@ -0,0 +1,11 @@ +// Copyright (c) 2023 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package com + +import ( + "math" +) + +const maxStreamRWLen = math.MaxInt32 diff --git a/vendor/github.com/dblohm7/wingoes/com/types.go b/vendor/github.com/dblohm7/wingoes/com/types.go new file mode 100644 index 0000000000..a965ba9256 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/types.go @@ -0,0 +1,166 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package com + +import ( + "unsafe" + + "github.com/dblohm7/wingoes" + "golang.org/x/sys/windows" +) + +// IID is a GUID that represents an interface ID. +type IID windows.GUID + +// CLSID is a GUID that represents a class ID. +type CLSID windows.GUID + +// AppID is a GUID that represents an application ID. +type AppID windows.GUID + +// ServiceID is a GUID that represents a service ID. +type ServiceID windows.GUID + +type coMTAUsageCookie windows.Handle + +type coCLSCTX uint32 + +const ( + // We intentionally do not define combinations of these values, as in my experience + // people don't realize what they're doing when they use those. 
+ coCLSCTX_INPROC_SERVER = coCLSCTX(0x1) + coCLSCTX_LOCAL_SERVER = coCLSCTX(0x4) + coCLSCTX_REMOTE_SERVER = coCLSCTX(0x10) +) + +type coAPTTYPE int32 + +const ( + coAPTTYPE_CURRENT = coAPTTYPE(-1) + coAPTTYPE_STA = coAPTTYPE(0) + coAPTTYPE_MTA = coAPTTYPE(1) + coAPTTYPE_NA = coAPTTYPE(2) + coAPTTYPE_MAINSTA = coAPTTYPE(3) +) + +type coAPTTYPEQUALIFIER int32 + +const ( + coAPTTYPEQUALIFIER_NONE = coAPTTYPEQUALIFIER(0) + coAPTTYPEQUALIFIER_IMPLICIT_MTA = coAPTTYPEQUALIFIER(1) + coAPTTYPEQUALIFIER_NA_ON_MTA = coAPTTYPEQUALIFIER(2) + coAPTTYPEQUALIFIER_NA_ON_STA = coAPTTYPEQUALIFIER(3) + coAPTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA = coAPTTYPEQUALIFIER(4) + coAPTTYPEQUALIFIER_NA_ON_MAINSTA = coAPTTYPEQUALIFIER(5) + coAPTTYPEQUALIFIER_APPLICATION_STA = coAPTTYPEQUALIFIER(6) +) + +type aptInfo struct { + apt coAPTTYPE + qualifier coAPTTYPEQUALIFIER +} + +type soleAuthenticationInfo struct { + authnSvc uint32 + authzSvc uint32 + authInfo uintptr +} + +type soleAuthenticationList struct { + count uint32 + authInfo *soleAuthenticationInfo +} + +type soleAuthenticationService struct { + authnSvc uint32 + authzSvc uint32 + principalName *uint16 + hr wingoes.HRESULT +} + +type authCapabilities uint32 + +const ( + authCapNone = authCapabilities(0) + authCapMutualAuth = authCapabilities(1) + authCapSecureRefs = authCapabilities(2) + authCapAccessControl = authCapabilities(4) + authCapAppID = authCapabilities(8) + authCapDynamic = authCapabilities(0x10) + authCapStaticCloaking = authCapabilities(0x20) + authCapDynamicCloaking = authCapabilities(0x40) + authCapAnyAuthority = authCapabilities(0x80) + authCapMakeFullsic = authCapabilities(0x100) + authCapRequireFullsic = authCapabilities(0x200) + authCapAutoImpersonate = authCapabilities(0x400) + authCapDefault = authCapabilities(0x800) + authCapDisableAAA = authCapabilities(0x1000) + authCapNoCustomMarshal = authCapabilities(0x2000) +) + +type rpcAuthnLevel uint32 + +const ( + rpcAuthnLevelDefault = rpcAuthnLevel(0) + rpcAuthnLevelNone = 
rpcAuthnLevel(1) + rpcAuthnLevelConnect = rpcAuthnLevel(2) + rpcAuthnLevelCall = rpcAuthnLevel(3) + rpcAuthnLevelPkt = rpcAuthnLevel(4) + rpcAuthnLevelPktIntegrity = rpcAuthnLevel(5) + rpcAuthnLevelPkgPrivacy = rpcAuthnLevel(6) +) + +type rpcImpersonationLevel uint32 + +const ( + rpcImpLevelDefault = rpcImpersonationLevel(0) + rpcImpLevelAnonymous = rpcImpersonationLevel(1) + rpcImpLevelIdentify = rpcImpersonationLevel(2) + rpcImpLevelImpersonate = rpcImpersonationLevel(3) + rpcImpLevelDelegate = rpcImpersonationLevel(4) +) + +// COMAllocatedString encapsulates a UTF-16 string that was allocated by COM +// using its internal heap. +type COMAllocatedString uintptr + +// Close frees the memory held by the string. +func (s *COMAllocatedString) Close() error { + windows.CoTaskMemFree(unsafe.Pointer(*s)) + *s = 0 + return nil +} + +func (s *COMAllocatedString) String() string { + return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(*s))) +} + +// UTF16 returns a slice containing a copy of the UTF-16 string, including a +// NUL terminator. +func (s *COMAllocatedString) UTF16() []uint16 { + p := (*uint16)(unsafe.Pointer(*s)) + if p == nil { + return nil + } + + n := 0 + for ptr := unsafe.Pointer(p); *(*uint16)(ptr) != 0; n++ { + ptr = unsafe.Pointer(uintptr(ptr) + unsafe.Sizeof(*p)) + } + + // Make a copy, including the NUL terminator. + return append([]uint16{}, unsafe.Slice(p, n+1)...) +} + +// UTF16Ptr returns a pointer to a NUL-terminated copy of the UTF-16 string. +func (s *COMAllocatedString) UTF16Ptr() *uint16 { + if slc := s.UTF16(); slc != nil { + return &slc[0] + } + + return nil +} diff --git a/vendor/github.com/dblohm7/wingoes/com/unknown.go b/vendor/github.com/dblohm7/wingoes/com/unknown.go new file mode 100644 index 0000000000..c1fb64c5f9 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/unknown.go @@ -0,0 +1,44 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package com + +import ( + "runtime" +) + +var ( + IID_IUnknown = &IID{0x00000000, 0x0000, 0x0000, [8]byte{0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}} +) + +// ObjectBase is a garbage-collected instance of any COM object's base interface. +type ObjectBase struct { + GenericObject[IUnknownABI] +} + +// IID always returns IID_IUnknown. +func (o ObjectBase) IID() *IID { + return IID_IUnknown +} + +// Make produces a new instance of ObjectBase that wraps r. Its return type is +// always ObjectBase. +func (o ObjectBase) Make(r ABIReceiver) any { + if r == nil { + return ObjectBase{} + } + + runtime.SetFinalizer(r, ReleaseABI) + + pp := (**IUnknownABI)(r) + return ObjectBase{GenericObject[IUnknownABI]{Pp: pp}} +} + +// UnsafeUnwrap returns the underlying IUnknownABI of the object. As the name +// implies, this is unsafe -- you had better know what you are doing! +func (o ObjectBase) UnsafeUnwrap() *IUnknownABI { + return *(o.Pp) +} diff --git a/vendor/github.com/dblohm7/wingoes/com/zsyscall_windows.go b/vendor/github.com/dblohm7/wingoes/com/zsyscall_windows.go new file mode 100644 index 0000000000..d057da5581 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/com/zsyscall_windows.go @@ -0,0 +1,106 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package com + +import ( + "syscall" + "unsafe" + + "github.com/dblohm7/wingoes" + "github.com/dblohm7/wingoes/internal" + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. 
+func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modole32 = windows.NewLazySystemDLL("ole32.dll") + modoleaut32 = windows.NewLazySystemDLL("oleaut32.dll") + modshlwapi = windows.NewLazySystemDLL("shlwapi.dll") + + procCoCreateInstance = modole32.NewProc("CoCreateInstance") + procCoGetApartmentType = modole32.NewProc("CoGetApartmentType") + procCoIncrementMTAUsage = modole32.NewProc("CoIncrementMTAUsage") + procCoInitializeEx = modole32.NewProc("CoInitializeEx") + procCoInitializeSecurity = modole32.NewProc("CoInitializeSecurity") + procCreateStreamOnHGlobal = modole32.NewProc("CreateStreamOnHGlobal") + procSetOaNoCache = modoleaut32.NewProc("SetOaNoCache") + procSHCreateMemStream = modshlwapi.NewProc("SHCreateMemStream") +) + +func coCreateInstance(clsid *CLSID, unkOuter *IUnknownABI, clsctx coCLSCTX, iid *IID, ppv **IUnknownABI) (hr wingoes.HRESULT) { + r0, _, _ := syscall.Syscall6(procCoCreateInstance.Addr(), 5, uintptr(unsafe.Pointer(clsid)), uintptr(unsafe.Pointer(unkOuter)), uintptr(clsctx), uintptr(unsafe.Pointer(iid)), uintptr(unsafe.Pointer(ppv)), 0) + hr = wingoes.HRESULT(r0) + return +} + +func coGetApartmentType(aptType *coAPTTYPE, qual *coAPTTYPEQUALIFIER) (hr wingoes.HRESULT) { + r0, _, _ := syscall.Syscall(procCoGetApartmentType.Addr(), 2, uintptr(unsafe.Pointer(aptType)), uintptr(unsafe.Pointer(qual)), 0) + hr = wingoes.HRESULT(r0) + return +} + +func coIncrementMTAUsage(cookie *coMTAUsageCookie) (hr wingoes.HRESULT) { + r0, _, _ := syscall.Syscall(procCoIncrementMTAUsage.Addr(), 1, uintptr(unsafe.Pointer(cookie)), 0, 0) + hr = wingoes.HRESULT(r0) + return +} + +func coInitializeEx(reserved uintptr, flags uint32) (hr wingoes.HRESULT) { + r0, _, _ := syscall.Syscall(procCoInitializeEx.Addr(), 2, 
uintptr(reserved), uintptr(flags), 0) + hr = wingoes.HRESULT(r0) + return +} + +func coInitializeSecurity(sd *windows.SECURITY_DESCRIPTOR, authSvcLen int32, authSvc *soleAuthenticationService, reserved1 uintptr, authnLevel rpcAuthnLevel, impLevel rpcImpersonationLevel, authList *soleAuthenticationList, capabilities authCapabilities, reserved2 uintptr) (hr wingoes.HRESULT) { + r0, _, _ := syscall.Syscall9(procCoInitializeSecurity.Addr(), 9, uintptr(unsafe.Pointer(sd)), uintptr(authSvcLen), uintptr(unsafe.Pointer(authSvc)), uintptr(reserved1), uintptr(authnLevel), uintptr(impLevel), uintptr(unsafe.Pointer(authList)), uintptr(capabilities), uintptr(reserved2)) + hr = wingoes.HRESULT(r0) + return +} + +func createStreamOnHGlobal(hglobal internal.HGLOBAL, deleteOnRelease bool, stream **IUnknownABI) (hr wingoes.HRESULT) { + var _p0 uint32 + if deleteOnRelease { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procCreateStreamOnHGlobal.Addr(), 3, uintptr(hglobal), uintptr(_p0), uintptr(unsafe.Pointer(stream))) + hr = wingoes.HRESULT(r0) + return +} + +func setOaNoCache() { + syscall.Syscall(procSetOaNoCache.Addr(), 0, 0, 0, 0) + return +} + +func shCreateMemStream(pInit *byte, cbInit uint32) (stream *IUnknownABI) { + r0, _, _ := syscall.Syscall(procSHCreateMemStream.Addr(), 2, uintptr(unsafe.Pointer(pInit)), uintptr(cbInit), 0) + stream = (*IUnknownABI)(unsafe.Pointer(r0)) + return +} diff --git a/vendor/github.com/dblohm7/wingoes/error.go b/vendor/github.com/dblohm7/wingoes/error.go new file mode 100644 index 0000000000..9516af8f15 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/error.go @@ -0,0 +1,311 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package wingoes + +import ( + "fmt" + + "golang.org/x/sys/windows" +) + +// HRESULT is equivalent to the HRESULT type in the Win32 SDK for C/C++. 
+type HRESULT int32 + +// Error represents various error codes that may be encountered when coding +// against Windows APIs, including HRESULTs, windows.NTStatus, and windows.Errno. +type Error HRESULT + +// Errors are HRESULTs under the hood because the HRESULT encoding allows for +// all the other common types of Windows errors to be encoded within them. + +const ( + hrS_OK = HRESULT(0) + hrE_ABORT = HRESULT(-((0x80004004 ^ 0xFFFFFFFF) + 1)) + hrE_FAIL = HRESULT(-((0x80004005 ^ 0xFFFFFFFF) + 1)) + hrE_NOINTERFACE = HRESULT(-((0x80004002 ^ 0xFFFFFFFF) + 1)) + hrE_NOTIMPL = HRESULT(-((0x80004001 ^ 0xFFFFFFFF) + 1)) + hrE_POINTER = HRESULT(-((0x80004003 ^ 0xFFFFFFFF) + 1)) + hrE_UNEXPECTED = HRESULT(-((0x8000FFFF ^ 0xFFFFFFFF) + 1)) + hrTYPE_E_WRONGTYPEKIND = HRESULT(-((0x8002802A ^ 0xFFFFFFFF) + 1)) +) + +// S_FALSE is a peculiar HRESULT value which means that the call executed +// successfully, but returned false as its result. +const S_FALSE = HRESULT(1) + +var ( + // genericError encodes an Error whose message string is very generic. + genericError = Error(hresultFromFacilityAndCode(hrFail, facilityWin32, hrCode(windows.ERROR_UNIDENTIFIED_ERROR))) +) + +// Common HRESULT codes that don't use Win32 facilities, but have meanings that +// we can manually translate to Win32 error codes. 
+var commonHRESULTToErrno = map[HRESULT]windows.Errno{ + hrE_ABORT: windows.ERROR_REQUEST_ABORTED, + hrE_FAIL: windows.ERROR_UNIDENTIFIED_ERROR, + hrE_NOINTERFACE: windows.ERROR_NOINTERFACE, + hrE_NOTIMPL: windows.ERROR_CALL_NOT_IMPLEMENTED, + hrE_UNEXPECTED: windows.ERROR_INTERNAL_ERROR, +} + +type hrCode uint16 +type hrFacility uint16 +type failBit bool + +const ( + hrFlagBitsMask = 0xF8000000 + hrFacilityMax = 0x00001FFF + hrFacilityMask = hrFacilityMax << 16 + hrCodeMax = 0x0000FFFF + hrCodeMask = hrCodeMax + hrFailBit = 0x80000000 + hrCustomerBit = 0x20000000 // Also defined as syscall.APPLICATION_ERROR + hrFacilityNTBit = 0x10000000 +) + +const ( + facilityWin32 = hrFacility(7) +) + +// Succeeded returns true when hr is successful, but its actual error code +// may include additional status information. +func (hr HRESULT) Succeeded() bool { + return hr >= 0 +} + +// Failed returns true when hr contains a failure code. +func (hr HRESULT) Failed() bool { + return hr < 0 +} + +func (hr HRESULT) isNT() bool { + return (hr & (hrCustomerBit | hrFacilityNTBit)) == hrFacilityNTBit +} + +func (hr HRESULT) isCustomer() bool { + return (hr & hrCustomerBit) != 0 +} + +// isNormal returns true when the customer and NT bits are cleared, ie hr's +// encoding contains valid facility and code fields. +func (hr HRESULT) isNormal() bool { + return (hr & (hrCustomerBit | hrFacilityNTBit)) == 0 +} + +// facility returns the facility bits of hr. Only valid when isNormal is true. +func (hr HRESULT) facility() hrFacility { + return hrFacility((uint32(hr) >> 16) & hrFacilityMax) +} + +// facility returns the code bits of hr. Only valid when isNormal is true. 
+func (hr HRESULT) code() hrCode { + return hrCode(uint32(hr) & hrCodeMask) +} + +const ( + hrFail = failBit(true) + hrSuccess = failBit(false) +) + +func hresultFromFacilityAndCode(isFail failBit, f hrFacility, c hrCode) HRESULT { + var r uint32 + if isFail { + r |= hrFailBit + } + r |= (uint32(f) << 16) & hrFacilityMask + r |= uint32(c) & hrCodeMask + return HRESULT(r) +} + +// ErrorFromErrno creates an Error from e. +func ErrorFromErrno(e windows.Errno) Error { + if e == windows.ERROR_SUCCESS { + return Error(hrS_OK) + } + if ue := uint32(e); (ue & hrFlagBitsMask) == hrCustomerBit { + // syscall.APPLICATION_ERROR == hrCustomerBit, so the only other thing + // we need to do to transform this into an HRESULT is add the fail flag + return Error(HRESULT(ue | hrFailBit)) + } + if uint32(e) > hrCodeMax { + // Can't be encoded in HRESULT, return generic error instead + return genericError + } + return Error(hresultFromFacilityAndCode(hrFail, facilityWin32, hrCode(e))) +} + +// ErrorFromNTStatus creates an Error from s. +func ErrorFromNTStatus(s windows.NTStatus) Error { + if s == windows.STATUS_SUCCESS { + return Error(hrS_OK) + } + return Error(HRESULT(s) | hrFacilityNTBit) +} + +// ErrorFromHRESULT creates an Error from hr. +func ErrorFromHRESULT(hr HRESULT) Error { + return Error(hr) +} + +// NewError converts e into an Error if e's type is supported. It returns +// both the Error and a bool indicating whether the conversion was successful. +func NewError(e any) (Error, bool) { + switch v := e.(type) { + case Error: + return v, true + case windows.NTStatus: + return ErrorFromNTStatus(v), true + case windows.Errno: + return ErrorFromErrno(v), true + case HRESULT: + return ErrorFromHRESULT(v), true + default: + return ErrorFromHRESULT(hrTYPE_E_WRONGTYPEKIND), false + } +} + +// IsOK returns true when the Error is unconditionally successful. 
+func (e Error) IsOK() bool { + return HRESULT(e) == hrS_OK +} + +// Succeeded returns true when the Error is successful, but its error code +// may include additional status information. +func (e Error) Succeeded() bool { + return HRESULT(e).Succeeded() +} + +// Failed returns true when the Error contains a failure code. +func (e Error) Failed() bool { + return HRESULT(e).Failed() +} + +// AsHRESULT converts the Error to a HRESULT. +func (e Error) AsHRESULT() HRESULT { + return HRESULT(e) +} + +type errnoFailHandler func(hr HRESULT) windows.Errno + +func (e Error) toErrno(f errnoFailHandler) windows.Errno { + hr := HRESULT(e) + + if hr == hrS_OK { + return windows.ERROR_SUCCESS + } + + if hr.isCustomer() { + return windows.Errno(uint32(e) ^ hrFailBit) + } + + if hr.isNT() { + return e.AsNTStatus().Errno() + } + + if hr.facility() == facilityWin32 { + return windows.Errno(hr.code()) + } + + if errno, ok := commonHRESULTToErrno[hr]; ok { + return errno + } + + return f(hr) +} + +// AsError converts the Error to a windows.Errno, but panics if not possible. +func (e Error) AsErrno() windows.Errno { + handler := func(hr HRESULT) windows.Errno { + panic(fmt.Sprintf("wingoes.Error: Called AsErrno on a non-convertable HRESULT 0x%08X", uint32(hr))) + return windows.ERROR_UNIDENTIFIED_ERROR + } + + return e.toErrno(handler) +} + +type ntStatusFailHandler func(hr HRESULT) windows.NTStatus + +func (e Error) toNTStatus(f ntStatusFailHandler) windows.NTStatus { + hr := HRESULT(e) + + if hr == hrS_OK { + return windows.STATUS_SUCCESS + } + + if hr.isNT() { + return windows.NTStatus(hr ^ hrFacilityNTBit) + } + + return f(hr) +} + +// AsNTStatus converts the Error to a windows.NTStatus, but panics if not possible. 
+func (e Error) AsNTStatus() windows.NTStatus { + handler := func(hr HRESULT) windows.NTStatus { + panic(fmt.Sprintf("windows.Error: Called AsNTStatus on a non-NTSTATUS HRESULT 0x%08X", uint32(hr))) + return windows.STATUS_UNSUCCESSFUL + } + + return e.toNTStatus(handler) +} + +// TryAsErrno converts the Error to a windows.Errno, or returns defval if +// such a conversion is not possible. +func (e Error) TryAsErrno(defval windows.Errno) windows.Errno { + handler := func(hr HRESULT) windows.Errno { + return defval + } + + return e.toErrno(handler) +} + +// TryAsNTStatus converts the Error to a windows.NTStatus, or returns defval if +// such a conversion is not possible. +func (e Error) TryAsNTStatus(defval windows.NTStatus) windows.NTStatus { + handler := func(hr HRESULT) windows.NTStatus { + return defval + } + + return e.toNTStatus(handler) +} + +// IsAvailableAsHRESULT returns true if e may be converted to an HRESULT. +func (e Error) IsAvailableAsHRESULT() bool { + return true +} + +// IsAvailableAsErrno returns true if e may be converted to a windows.Errno. +func (e Error) IsAvailableAsErrno() bool { + hr := HRESULT(e) + if hr.isCustomer() || e.IsAvailableAsNTStatus() || (hr.facility() == facilityWin32) { + return true + } + _, convertable := commonHRESULTToErrno[hr] + return convertable +} + +// IsAvailableAsNTStatus returns true if e may be converted to a windows.NTStatus. +func (e Error) IsAvailableAsNTStatus() bool { + return HRESULT(e) == hrS_OK || HRESULT(e).isNT() +} + +// Error produces a human-readable message describing Error e. 
+func (e Error) Error() string { + if HRESULT(e).isCustomer() { + return windows.Errno(uint32(e) ^ hrFailBit).Error() + } + + buf := make([]uint16, 300) + const flags = windows.FORMAT_MESSAGE_FROM_SYSTEM | windows.FORMAT_MESSAGE_IGNORE_INSERTS + lenExclNul, err := windows.FormatMessage(flags, 0, uint32(e), 0, buf, nil) + if err != nil { + return fmt.Sprintf("wingoes.Error 0x%08X", uint32(e)) + } + for ; lenExclNul > 0 && (buf[lenExclNul-1] == '\n' || buf[lenExclNul-1] == '\r'); lenExclNul-- { + } + return windows.UTF16ToString(buf[:lenExclNul]) +} diff --git a/vendor/github.com/dblohm7/wingoes/error_notwindows.go b/vendor/github.com/dblohm7/wingoes/error_notwindows.go new file mode 100644 index 0000000000..b9237766a7 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/error_notwindows.go @@ -0,0 +1,9 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package wingoes + +// HRESULT is equivalent to the HRESULT type in the Win32 SDK for C/C++. +type HRESULT int32 diff --git a/vendor/github.com/dblohm7/wingoes/guid.go b/vendor/github.com/dblohm7/wingoes/guid.go new file mode 100644 index 0000000000..f9436ebb7d --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/guid.go @@ -0,0 +1,24 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package wingoes + +import ( + "fmt" + + "golang.org/x/sys/windows" +) + +// MustGetGUID parses s, a string containing a GUID and returns a pointer to the +// parsed GUID. s must be specified in the format "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}". +// If there is an error parsing s, MustGetGUID panics. 
+func MustGetGUID(s string) *windows.GUID { + guid, err := windows.GUIDFromString(s) + if err != nil { + panic(fmt.Sprintf("wingoes.MustGetGUID(%q) error %v", s, err)) + } + return &guid +} diff --git a/vendor/github.com/dblohm7/wingoes/internal/types.go b/vendor/github.com/dblohm7/wingoes/internal/types.go new file mode 100644 index 0000000000..62fe249f16 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/internal/types.go @@ -0,0 +1,13 @@ +// Copyright (c) 2023 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package internal + +import ( + "golang.org/x/sys/windows" +) + +type HGLOBAL windows.Handle diff --git a/vendor/github.com/dblohm7/wingoes/osversion.go b/vendor/github.com/dblohm7/wingoes/osversion.go new file mode 100644 index 0000000000..07ebcc4aab --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/osversion.go @@ -0,0 +1,193 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package wingoes + +import ( + "fmt" + "sync" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +var ( + verOnce sync.Once + verInfo osVersionInfo // must access via getVersionInfo() +) + +// osVersionInfo is more compact than windows.OsVersionInfoEx, which contains +// extraneous information. 
+type osVersionInfo struct { + major uint32 + minor uint32 + build uint32 + servicePack uint16 + str string + isDC bool + isServer bool +} + +const ( + _VER_NT_WORKSTATION = 1 + _VER_NT_DOMAIN_CONTROLLER = 2 + _VER_NT_SERVER = 3 +) + +func getVersionInfo() *osVersionInfo { + verOnce.Do(func() { + osv := windows.RtlGetVersion() + verInfo = osVersionInfo{ + major: osv.MajorVersion, + minor: osv.MinorVersion, + build: osv.BuildNumber, + servicePack: osv.ServicePackMajor, + str: fmt.Sprintf("%d.%d.%d", osv.MajorVersion, osv.MinorVersion, osv.BuildNumber), + isDC: osv.ProductType == _VER_NT_DOMAIN_CONTROLLER, + // Domain Controllers are also implicitly servers. + isServer: osv.ProductType == _VER_NT_DOMAIN_CONTROLLER || osv.ProductType == _VER_NT_SERVER, + } + // UBR is only available on Windows 10 and 11 (MajorVersion == 10). + if osv.MajorVersion == 10 { + if ubr, err := getUBR(); err == nil { + verInfo.str = fmt.Sprintf("%s.%d", verInfo.str, ubr) + } + } + }) + return &verInfo +} + +// getUBR returns the "update build revision," ie. the fourth component of the +// version string found on Windows 10 and Windows 11 systems. +func getUBR() (uint32, error) { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, + `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE|registry.WOW64_64KEY) + if err != nil { + return 0, err + } + defer key.Close() + + val, valType, err := key.GetIntegerValue("UBR") + if err != nil { + return 0, err + } + if valType != registry.DWORD { + return 0, registry.ErrUnexpectedType + } + + return uint32(val), nil +} + +// GetOSVersionString returns the Windows version of the current machine in +// dotted-decimal form. The version string contains 3 components on Windows 7 +// and 8.x, and 4 components on Windows 10 and 11. +func GetOSVersionString() string { + return getVersionInfo().String() +} + +// IsWinServer returns true if and only if this computer's version of Windows is +// a server edition. 
+func IsWinServer() bool { + return getVersionInfo().isServer +} + +// IsWinDomainController returs true if this computer's version of Windows is +// configured to act as a domain controller. +func IsWinDomainController() bool { + return getVersionInfo().isDC +} + +// IsWin7SP1OrGreater returns true when running on Windows 7 SP1 or newer. +func IsWin7SP1OrGreater() bool { + if IsWin8OrGreater() { + return true + } + + vi := getVersionInfo() + return vi.major == 6 && vi.minor == 1 && vi.servicePack > 0 +} + +// IsWin8OrGreater returns true when running on Windows 8.0 or newer. +func IsWin8OrGreater() bool { + return getVersionInfo().isVersionOrGreater(6, 2, 0) +} + +// IsWin8Point1OrGreater returns true when running on Windows 8.1 or newer. +func IsWin8Point1OrGreater() bool { + return getVersionInfo().isVersionOrGreater(6, 3, 0) +} + +// IsWin10OrGreater returns true when running on any build of Windows 10 or newer. +func IsWin10OrGreater() bool { + return getVersionInfo().major >= 10 +} + +// Win10BuildConstant encodes build numbers for the various editions of Windows 10, +// for use with IsWin10BuildOrGreater. +type Win10BuildConstant uint32 + +const ( + Win10BuildNov2015 = Win10BuildConstant(10586) + Win10BuildAnniversary = Win10BuildConstant(14393) + Win10BuildCreators = Win10BuildConstant(15063) + Win10BuildFallCreators = Win10BuildConstant(16299) + Win10BuildApr2018 = Win10BuildConstant(17134) + Win10BuildSep2018 = Win10BuildConstant(17763) + Win10BuildMay2019 = Win10BuildConstant(18362) + Win10BuildSep2019 = Win10BuildConstant(18363) + Win10BuildApr2020 = Win10BuildConstant(19041) + Win10Build20H2 = Win10BuildConstant(19042) + Win10Build21H1 = Win10BuildConstant(19043) + Win10Build21H2 = Win10BuildConstant(19044) +) + +// IsWin10BuildOrGreater returns true when running on the specified Windows 10 +// build, or newer. 
+func IsWin10BuildOrGreater(build Win10BuildConstant) bool { + return getVersionInfo().isWin10BuildOrGreater(uint32(build)) +} + +// Win11BuildConstant encodes build numbers for the various editions of Windows 11, +// for use with IsWin11BuildOrGreater. +type Win11BuildConstant uint32 + +const ( + Win11BuildRTM = Win11BuildConstant(22000) + Win11Build22H2 = Win11BuildConstant(22621) +) + +// IsWin11OrGreater returns true when running on any release of Windows 11, +// or newer. +func IsWin11OrGreater() bool { + return IsWin11BuildOrGreater(Win11BuildRTM) +} + +// IsWin11BuildOrGreater returns true when running on the specified Windows 11 +// build, or newer. +func IsWin11BuildOrGreater(build Win11BuildConstant) bool { + // Under the hood, Windows 11 is just Windows 10 with a sufficiently advanced + // build number. + return getVersionInfo().isWin10BuildOrGreater(uint32(build)) +} + +func (osv *osVersionInfo) String() string { + return osv.str +} + +func (osv *osVersionInfo) isWin10BuildOrGreater(build uint32) bool { + return osv.isVersionOrGreater(10, 0, build) +} + +func (osv *osVersionInfo) isVersionOrGreater(major, minor, build uint32) bool { + return isVerGE(osv.major, major, osv.minor, minor, osv.build, build) +} + +func isVerGE(lmajor, rmajor, lminor, rminor, lbuild, rbuild uint32) bool { + return lmajor > rmajor || + lmajor == rmajor && + (lminor > rminor || + lminor == rminor && lbuild >= rbuild) +} diff --git a/vendor/github.com/dblohm7/wingoes/pe/pe.go b/vendor/github.com/dblohm7/wingoes/pe/pe.go new file mode 100644 index 0000000000..78a7da11ee --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/pe/pe.go @@ -0,0 +1,784 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +// Package pe provides facilities for extracting information from PE binaries. 
+package pe + +import ( + "bufio" + "bytes" + dpe "debug/pe" + "encoding/binary" + "errors" + "fmt" + "io" + "math/bits" + "os" + "reflect" + "strings" + "unsafe" + + "golang.org/x/exp/constraints" + "golang.org/x/sys/windows" +) + +// The following constants are from the PE spec +const ( + offsetIMAGE_DOS_HEADERe_lfanew = 0x3C + maxNumSections = 96 +) + +var ( + ErrBadLength = errors.New("effective length did not match expected length") + ErrBadCodeView = errors.New("invalid CodeView debug info") + ErrNotCodeView = errors.New("debug info is not CodeView") + ErrNotPresent = errors.New("not present in this PE image") + ErrIndexOutOfRange = errors.New("index out of range") + // ErrInvalidBinary is returned whenever the headers do not parse as expected, + // or reference locations outside the bounds of the PE file or module. + // The headers might be corrupt, malicious, or have been tampered with. + ErrInvalidBinary = errors.New("invalid PE binary") + ErrResolvingFileRVA = errors.New("could not resolve file RVA") + ErrUnavailableInModule = errors.New("this information is unavailable from loaded modules; the PE file itself must be examined") + ErrUnsupportedMachine = errors.New("unsupported machine") +) + +type peReader interface { + Base() uintptr + io.Closer + io.ReaderAt + io.ReadSeeker + Limit() uintptr +} + +// PEHeaders represents the partially-parsed headers from a PE binary. 
+type PEHeaders struct { + r peReader + fileHeader *dpe.FileHeader + optionalHeader *optionalHeader + sections []peSectionHeader +} + +type peBounds struct { + base uintptr + limit uintptr +} + +type peFile struct { + *os.File + peBounds +} + +func (pef *peFile) Base() uintptr { + return pef.peBounds.base +} + +func (pef *peFile) Limit() uintptr { + if pef.limit == 0 { + if fi, err := pef.Stat(); err == nil { + pef.limit = uintptr(fi.Size()) + } + } + return pef.limit +} + +type peModule struct { + *bytes.Reader + peBounds + modLock windows.Handle +} + +func (pei *peModule) Base() uintptr { + return pei.peBounds.base +} + +func (pei *peModule) Close() error { + return windows.FreeLibrary(pei.modLock) +} + +func (pei *peModule) Limit() uintptr { + return pei.peBounds.limit +} + +// NewPEFromBaseAddressAndSize parses the headers in a PE binary loaded +// into the current process's address space at address baseAddr with known +// size. If you do not have the size, use NewPEFromBaseAddress instead. +// Upon success it returns a non-nil *PEHeaders, otherwise it returns a nil +// *PEHeaders and a non-nil error. +func NewPEFromBaseAddressAndSize(baseAddr uintptr, size uint32) (*PEHeaders, error) { + // Grab a strong reference to the module until we're done with it. + var modLock windows.Handle + if err := windows.GetModuleHandleEx( + windows.GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, + (*uint16)(unsafe.Pointer(baseAddr)), + &modLock, + ); err != nil { + return nil, err + } + + slc := unsafe.Slice((*byte)(unsafe.Pointer(baseAddr)), size) + r := bytes.NewReader(slc) + peMod := &peModule{ + Reader: r, + peBounds: peBounds{ + base: baseAddr, + limit: baseAddr + uintptr(size), + }, + modLock: modLock, + } + return loadHeaders(peMod) +} + +// NewPEFromBaseAddress parses the headers in a PE binary loaded into the +// current process's address space at address baseAddr. +// Upon success it returns a non-nil *PEHeaders, otherwise it returns a nil +// *PEHeaders and a non-nil error. 
+func NewPEFromBaseAddress(baseAddr uintptr) (*PEHeaders, error) { + var modInfo windows.ModuleInfo + if err := windows.GetModuleInformation( + windows.CurrentProcess(), + windows.Handle(baseAddr), + &modInfo, + uint32(unsafe.Sizeof(modInfo)), + ); err != nil { + return nil, fmt.Errorf("querying module handle: %w", err) + } + + return NewPEFromBaseAddressAndSize(baseAddr, modInfo.SizeOfImage) +} + +// NewPEFromHMODULE parses the headers in a PE binary identified by hmodule that +// is currently loaded into the current process's address space. +// Upon success it returns a non-nil *PEHeaders, otherwise it returns a nil +// *PEHeaders and a non-nil error. +func NewPEFromHMODULE(hmodule windows.Handle) (*PEHeaders, error) { + // HMODULEs are just a loaded module's base address with the lowest two + // bits used for flags (see docs for LoadLibraryExW). + return NewPEFromBaseAddress(uintptr(hmodule) & ^uintptr(3)) +} + +// NewPEFromDLL parses the headers in a PE binary identified by dll that +// is currently loaded into the current process's address space. +// Upon success it returns a non-nil *PEHeaders, otherwise it returns a nil +// *PEHeaders and a non-nil error. +// If the DLL is Release()d while the returned *PEHeaders is still in use, +// its behaviour will become undefined. +func NewPEFromDLL(dll *windows.DLL) (*PEHeaders, error) { + if dll == nil || dll.Handle == 0 { + return nil, os.ErrInvalid + } + + return NewPEFromHMODULE(dll.Handle) +} + +// NewPEFromLazyDLL parses the headers in a PE binary identified by ldll that +// is currently loaded into the current process's address space. +// Upon success it returns a non-nil *PEHeaders, otherwise it returns a nil +// *PEHeaders and a non-nil error. 
+func NewPEFromLazyDLL(ldll *windows.LazyDLL) (*PEHeaders, error) { + if ldll == nil { + return nil, os.ErrInvalid + } + if err := ldll.Load(); err != nil { + return nil, err + } + + return NewPEFromHMODULE(windows.Handle(ldll.Handle())) +} + +// NewPEFromFileName opens a PE binary located at filename and parses its PE +// headers. Upon success it returns a non-nil *PEHeaders, otherwise it returns a +// nil *PEHeaders and a non-nil error. +// Call Close() on the returned *PEHeaders when it is no longer needed. +func NewPEFromFileName(filename string) (*PEHeaders, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + + return newPEFromFile(f) +} + +// NewPEFromFileHandle parses the PE headers from hfile, an open Win32 file handle. +// It does *not* consume hfile. +// Upon success it returns a non-nil *PEHeaders, otherwise it returns a +// nil *PEHeaders and a non-nil error. +// Call Close() on the returned *PEHeaders when it is no longer needed. +func NewPEFromFileHandle(hfile windows.Handle) (*PEHeaders, error) { + if hfile == 0 || hfile == windows.InvalidHandle { + return nil, os.ErrInvalid + } + + // Duplicate hfile so that we don't consume it. + var hfileDup windows.Handle + cp := windows.CurrentProcess() + if err := windows.DuplicateHandle( + cp, + hfile, + cp, + &hfileDup, + 0, + false, + windows.DUPLICATE_SAME_ACCESS, + ); err != nil { + return nil, err + } + + return newPEFromFile(os.NewFile(uintptr(hfileDup), "PEFromFileHandle")) +} + +func newPEFromFile(f *os.File) (*PEHeaders, error) { + // peBounds base is 0, limit is loaded lazily + pef := &peFile{File: f} + return loadHeaders(pef) +} + +func (peh *PEHeaders) Close() error { + return peh.r.Close() +} + +type rvaType interface { + ~int8 | ~int16 | ~int32 | ~uint8 | ~uint16 | ~uint32 +} + +// addOffset ensures that, if off is negative, it does not underflow base. 
+func addOffset[O rvaType](base uintptr, off O) uintptr { + if off >= 0 { + return base + uintptr(off) + } + + negation := uintptr(-off) + if negation >= base { + return 0 + } + return base - negation +} + +func binaryRead(r io.Reader, data any) (err error) { + // Windows is always LittleEndian + err = binary.Read(r, binary.LittleEndian, data) + if err == io.ErrUnexpectedEOF { + err = ErrBadLength + } + return err +} + +// readStruct reads a T from offset rva. If r is a *peModule, the returned *T +// points to the data in-place. +// Note that currently this function will fail if rva references memory beyond +// the bounds of the binary; in the case of modules, this may need to be relaxed +// in some cases due to tampering by third-party crapware. +func readStruct[T any, O rvaType](r peReader, rva O) (*T, error) { + switch v := r.(type) { + case *peFile: + if _, err := r.Seek(int64(rva), io.SeekStart); err != nil { + return nil, err + } + + result := new(T) + if err := binaryRead(r, result); err != nil { + return nil, err + } + + return result, nil + case *peModule: + addr := addOffset(r.Base(), rva) + szT := unsafe.Sizeof(*((*T)(nil))) + if addr+szT >= v.Limit() { + return nil, ErrInvalidBinary + } + + return (*T)(unsafe.Pointer(addr)), nil + default: + return nil, os.ErrInvalid + } +} + +// readStructArray reads a []T with length count from offset rva. If r is a +// *peModule, the returned []T references the data in-place. +// Note that currently this function will fail if rva references memory beyond +// the bounds of the binary; in the case of modules, this may need to be relaxed +// in some cases due to tampering by third-party crapware. 
+func readStructArray[T any, O rvaType](r peReader, rva O, count int) ([]T, error) { + switch v := r.(type) { + case *peFile: + if _, err := r.Seek(int64(rva), io.SeekStart); err != nil { + return nil, err + } + + result := make([]T, count) + if err := binaryRead(r, result); err != nil { + return nil, err + } + + return result, nil + case *peModule: + addr := addOffset(r.Base(), rva) + szT := reflect.ArrayOf(count, reflect.TypeOf((*T)(nil)).Elem()).Size() + if addr+szT >= v.Limit() { + return nil, ErrInvalidBinary + } + + return unsafe.Slice((*T)(unsafe.Pointer(addr)), count), nil + default: + return nil, os.ErrInvalid + } +} + +type peSectionHeader dpe.SectionHeader32 + +func (s *peSectionHeader) NameAsString() string { + // s.Name is UTF-8. When the string's length is < len(s.Name), the remaining + // bytes are padded with zeros. + for i, c := range s.Name { + if c == 0 { + return string(s.Name[:i]) + } + } + + return string(s.Name[:]) +} + +func loadHeaders(r peReader) (*PEHeaders, error) { + // Check the signature of the DOS stub header + var mz [2]byte + if _, err := r.ReadAt(mz[:], 0); err != nil { + if err == io.EOF { + err = ErrInvalidBinary + } + return nil, err + } + if mz[0] != 'M' || mz[1] != 'Z' { + return nil, ErrInvalidBinary + } + + // Seek to the offset of the value that points to the beginning of the PE headers + if _, err := r.Seek(offsetIMAGE_DOS_HEADERe_lfanew, io.SeekStart); err != nil { + return nil, err + } + + // Load the offset to the beginning of the PE headers + var e_lfanew int32 + if err := binaryRead(r, &e_lfanew); err != nil { + if err == ErrBadLength { + err = ErrInvalidBinary + } + return nil, err + } + if e_lfanew <= 0 { + return nil, ErrInvalidBinary + } + if addOffset(r.Base(), e_lfanew) >= r.Limit() { + return nil, ErrInvalidBinary + } + + // Check the PE signature + var peSig [4]byte + if _, err := r.ReadAt(peSig[:], int64(e_lfanew)); err != nil { + if err == io.EOF { + err = ErrInvalidBinary + } + return nil, err + } + if 
peSig[0] != 'P' || peSig[1] != 'E' || peSig[2] != 0 || peSig[3] != 0 { + return nil, ErrInvalidBinary + } + + // Read the file header + fileHeaderOffset := uint32(e_lfanew) + uint32(unsafe.Sizeof(peSig)) + if r.Base()+uintptr(fileHeaderOffset) >= r.Limit() { + return nil, ErrInvalidBinary + } + + fileHeader, err := readStruct[dpe.FileHeader](r, fileHeaderOffset) + if err != nil { + return nil, err + } + + // In-memory modules should always have a machine type that matches our own. + // (okay, so that's kinda sorta untrue with respect to WOW64, but that's + // a _very_ obscure use case). + _, isModule := r.(*peModule) + // TODO(aaron): Uncomment once we can read binaries from disk whose archs + // do not necessarily match our own. + if /*isModule &&*/ fileHeader.Machine != expectedMachine { + return nil, ErrUnsupportedMachine + } + + // Read the optional header + optionalHeaderOffset := uint32(fileHeaderOffset) + uint32(unsafe.Sizeof(*fileHeader)) + if r.Base()+uintptr(optionalHeaderOffset) >= r.Limit() { + return nil, ErrInvalidBinary + } + + // TODO(aaron): parameterize optional header type so we can read binaries + // from disk whose archs do not necessarily match our own. 
+ optionalHeader, err := readStruct[optionalHeader](r, optionalHeaderOffset) + if err != nil { + return nil, err + } + + // Check the optional header's Magic field + expectedOptionalHeaderMagic := uint16(optionalHeaderMagic) + if !isModule { + switch fileHeader.Machine { + case dpe.IMAGE_FILE_MACHINE_I386: + expectedOptionalHeaderMagic = 0x010B + case dpe.IMAGE_FILE_MACHINE_AMD64, dpe.IMAGE_FILE_MACHINE_ARM64: + expectedOptionalHeaderMagic = 0x020B + default: + return nil, ErrUnsupportedMachine + } + } + + if optionalHeader.Magic != expectedOptionalHeaderMagic { + return nil, ErrInvalidBinary + } + + // Coarse-grained check that header sizes make sense + totalEssentialHeaderLen := uint32(offsetIMAGE_DOS_HEADERe_lfanew) + + uint32(unsafe.Sizeof(e_lfanew)) + + uint32(unsafe.Sizeof(*fileHeader)) + + uint32(fileHeader.SizeOfOptionalHeader) + if optionalHeader.SizeOfImage < totalEssentialHeaderLen { + return nil, ErrInvalidBinary + } + + numSections := fileHeader.NumberOfSections + if numSections > maxNumSections { + // More than 96 sections?! Really?! + return nil, ErrInvalidBinary + } + + // Read in the section table + sectionTableOffset := optionalHeaderOffset + uint32(fileHeader.SizeOfOptionalHeader) + if r.Base()+uintptr(sectionTableOffset) >= r.Limit() { + return nil, ErrInvalidBinary + } + + sections, err := readStructArray[peSectionHeader](r, sectionTableOffset, int(numSections)) + if err != nil { + return nil, err + } + + return &PEHeaders{r: r, fileHeader: fileHeader, optionalHeader: optionalHeader, sections: sections}, nil +} + +type rva32 interface { + ~int32 | ~uint32 +} + +// resolveRVA resolves rva, or returns 0 if unavailable. +func resolveRVA[O rva32](nfo *PEHeaders, rva O) O { + if _, ok := nfo.r.(*peFile); !ok { + // Just the identity function in this case. + return rva + } + + if rva <= 0 { + return 0 + } + + // We walk the section table, locating the section that would contain rva if + // we were mapped into memory. 
We then calculate the offset of rva from the + // starting virtual address of the section, and then add that offset to the + // section's starting file pointer. + urva := uint32(rva) + for _, s := range nfo.sections { + if urva < s.VirtualAddress { + continue + } + if urva >= (s.VirtualAddress + s.VirtualSize) { + continue + } + voff := urva - s.VirtualAddress + foff := s.PointerToRawData + voff + if foff >= s.PointerToRawData+s.SizeOfRawData { + return 0 + } + return O(foff) + } + + return 0 +} + +type DataDirectoryEntry = dpe.DataDirectory + +func (nfo *PEHeaders) dataDirectory() []DataDirectoryEntry { + cnt := nfo.optionalHeader.NumberOfRvaAndSizes + if maxCnt := uint32(len(nfo.optionalHeader.DataDirectory)); cnt > maxCnt { + cnt = maxCnt + } + return nfo.optionalHeader.DataDirectory[:cnt] +} + +// DataDirectoryIndex is an enumeration specifying a particular entry in the +// data directory. +type DataDirectoryIndex int + +const ( + IMAGE_DIRECTORY_ENTRY_EXPORT = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_EXPORT) + IMAGE_DIRECTORY_ENTRY_IMPORT = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_IMPORT) + IMAGE_DIRECTORY_ENTRY_RESOURCE = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_RESOURCE) + IMAGE_DIRECTORY_ENTRY_EXCEPTION = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_EXCEPTION) + IMAGE_DIRECTORY_ENTRY_SECURITY = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_SECURITY) + IMAGE_DIRECTORY_ENTRY_BASERELOC = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_BASERELOC) + IMAGE_DIRECTORY_ENTRY_DEBUG = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_DEBUG) + IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_ARCHITECTURE) + IMAGE_DIRECTORY_ENTRY_GLOBALPTR = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_GLOBALPTR) + IMAGE_DIRECTORY_ENTRY_TLS = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_TLS) + IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG) + IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 
DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT) + IMAGE_DIRECTORY_ENTRY_IAT = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_IAT) + IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT) + IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = DataDirectoryIndex(dpe.IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR) +) + +// DataDirectoryEntry returns information from nfo's data directory at index idx. +// The type of the return value depends on the value of idx. Most values for idx +// currently return the DataDirectoryEntry itself, however it will return more +// sophisticated information for the following values of idx: +// +// IMAGE_DIRECTORY_ENTRY_SECURITY returns []AuthenticodeCert +// IMAGE_DIRECTORY_ENTRY_DEBUG returns []IMAGE_DEBUG_DIRECTORY +// +// Note that other idx values _will_ be modified in the future to support more +// sophisticated return values, so be careful to structure your type assertions +// accordingly. +func (nfo *PEHeaders) DataDirectoryEntry(idx DataDirectoryIndex) (any, error) { + dd := nfo.dataDirectory() + if int(idx) >= len(dd) { + return nil, ErrIndexOutOfRange + } + + dde := dd[idx] + if dde.VirtualAddress == 0 || dde.Size == 0 { + return nil, ErrNotPresent + } + + switch idx { + /* TODO(aaron): (don't forget to sync tests!) + case IMAGE_DIRECTORY_ENTRY_EXPORT: + case IMAGE_DIRECTORY_ENTRY_IMPORT: + case IMAGE_DIRECTORY_ENTRY_RESOURCE: + */ + case IMAGE_DIRECTORY_ENTRY_SECURITY: + return nfo.extractAuthenticode(dde) + case IMAGE_DIRECTORY_ENTRY_DEBUG: + return nfo.extractDebugInfo(dde) + // case IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR: + default: + return dde, nil + } +} + +// WIN_CERT_REVISION is an enumeration from the Windows SDK. +type WIN_CERT_REVISION uint16 + +const ( + WIN_CERT_REVISION_1_0 WIN_CERT_REVISION = 0x0100 + WIN_CERT_REVISION_2_0 WIN_CERT_REVISION = 0x0200 +) + +// WIN_CERT_TYPE is an enumeration from the Windows SDK. 
+type WIN_CERT_TYPE uint16 + +const ( + WIN_CERT_TYPE_X509 WIN_CERT_TYPE = 0x0001 + WIN_CERT_TYPE_PKCS_SIGNED_DATA WIN_CERT_TYPE = 0x0002 + WIN_CERT_TYPE_TS_STACK_SIGNED WIN_CERT_TYPE = 0x0004 +) + +type _WIN_CERTIFICATE_HEADER struct { + Length uint32 + Revision WIN_CERT_REVISION + CertificateType WIN_CERT_TYPE +} + +// AuthenticodeCert represents an authenticode signature that has been extracted +// from a signed PE binary but not fully parsed. +type AuthenticodeCert struct { + header _WIN_CERTIFICATE_HEADER + data []byte +} + +// Revision returns the revision of ac. +func (ac *AuthenticodeCert) Revision() WIN_CERT_REVISION { + return ac.header.Revision +} + +// Type returns the type of ac. +func (ac *AuthenticodeCert) Type() WIN_CERT_TYPE { + return ac.header.CertificateType +} + +// Data returns the raw bytes of ac's cert. +func (ac *AuthenticodeCert) Data() []byte { + return ac.data +} + +func alignUp[V constraints.Integer](v V, powerOfTwo uint8) V { + if bits.OnesCount8(powerOfTwo) != 1 { + panic("invalid powerOfTwo argument to alignUp") + } + return v + ((-v) & (V(powerOfTwo) - 1)) +} + +// IMAGE_DEBUG_DIRECTORY describes debug information embedded in the binary. +type IMAGE_DEBUG_DIRECTORY struct { + Characteristics uint32 + TimeDateStamp uint32 + MajorVersion uint16 + MinorVersion uint16 + Type uint32 // an IMAGE_DEBUG_TYPE constant + SizeOfData uint32 + AddressOfRawData uint32 + PointerToRawData uint32 +} + +// IMAGE_DEBUG_TYPE_CODEVIEW identifies the current IMAGE_DEBUG_DIRECTORY as +// pointing to CodeView debug information. +const IMAGE_DEBUG_TYPE_CODEVIEW = 2 + +// IMAGE_DEBUG_INFO_CODEVIEW_UNPACKED contains CodeView debug information +// embedded in the PE file. Note that this structure's ABI does not match its C +// counterpart because the former uses a Go string and the latter is packed and +// also includes a signature field. 
+type IMAGE_DEBUG_INFO_CODEVIEW_UNPACKED struct { + GUID windows.GUID + Age uint32 + PDBPath string +} + +// String returns the data from u formatted in the same way that Microsoft +// debugging tools and symbol servers use to identify PDB files corresponding +// to a specific binary. +func (u *IMAGE_DEBUG_INFO_CODEVIEW_UNPACKED) String() string { + var b strings.Builder + fmt.Fprintf(&b, "%08X%04X%04X", u.GUID.Data1, u.GUID.Data2, u.GUID.Data3) + for _, v := range u.GUID.Data4 { + fmt.Fprintf(&b, "%02X", v) + } + fmt.Fprintf(&b, "%X", u.Age) + return b.String() +} + +const codeViewSignature = 0x53445352 + +func (u *IMAGE_DEBUG_INFO_CODEVIEW_UNPACKED) unpack(r *bufio.Reader) error { + var signature uint32 + if err := binaryRead(r, &signature); err != nil { + return err + } + if signature != codeViewSignature { + return ErrBadCodeView + } + + if err := binaryRead(r, &u.GUID); err != nil { + return err + } + + if err := binaryRead(r, &u.Age); err != nil { + return err + } + + pdbBytes := make([]byte, 0, 16) + for b, err := r.ReadByte(); err == nil && b != 0; b, err = r.ReadByte() { + pdbBytes = append(pdbBytes, b) + } + + u.PDBPath = string(pdbBytes) + return nil +} + +func (nfo *PEHeaders) extractDebugInfo(dde DataDirectoryEntry) (any, error) { + rva := resolveRVA(nfo, dde.VirtualAddress) + if rva == 0 { + return nil, ErrResolvingFileRVA + } + + count := dde.Size / uint32(unsafe.Sizeof(IMAGE_DEBUG_DIRECTORY{})) + return readStructArray[IMAGE_DEBUG_DIRECTORY](nfo.r, rva, int(count)) +} + +// ExtractCodeViewInfo obtains CodeView debug information from de, assuming that +// de represents CodeView debug info. 
+func (nfo *PEHeaders) ExtractCodeViewInfo(de IMAGE_DEBUG_DIRECTORY) (*IMAGE_DEBUG_INFO_CODEVIEW_UNPACKED, error) { + if de.Type != IMAGE_DEBUG_TYPE_CODEVIEW { + return nil, ErrNotCodeView + } + + cv := new(IMAGE_DEBUG_INFO_CODEVIEW_UNPACKED) + var sr *io.SectionReader + switch v := nfo.r.(type) { + case *peFile: + sr = io.NewSectionReader(v, int64(de.PointerToRawData), int64(de.SizeOfData)) + case *peModule: + sr = io.NewSectionReader(v, int64(de.AddressOfRawData), int64(de.SizeOfData)) + default: + return nil, ErrInvalidBinary + } + + if err := cv.unpack(bufio.NewReader(sr)); err != nil { + return nil, err + } + + return cv, nil +} + +func readFull(r io.Reader, buf []byte) (n int, err error) { + n, err = io.ReadFull(r, buf) + if err == io.ErrUnexpectedEOF { + err = ErrBadLength + } + return n, err +} + +func (nfo *PEHeaders) extractAuthenticode(dde DataDirectoryEntry) (any, error) { + if _, ok := nfo.r.(*peFile); !ok { + // Authenticode; only available in file, not loaded at runtime. + return nil, ErrUnavailableInModule + } + + var result []AuthenticodeCert + // The VirtualAddress is a file offset. 
+ sr := io.NewSectionReader(nfo.r, int64(dde.VirtualAddress), int64(dde.Size)) + var curOffset int64 + szEntry := unsafe.Sizeof(_WIN_CERTIFICATE_HEADER{}) + + for { + var entry AuthenticodeCert + if err := binaryRead(sr, &entry.header); err != nil { + if err == io.EOF { + break + } + return nil, err + } + curOffset += int64(szEntry) + + if uintptr(entry.header.Length) < szEntry { + return nil, ErrInvalidBinary + } + + entry.data = make([]byte, uintptr(entry.header.Length)-szEntry) + n, err := readFull(sr, entry.data) + if err != nil { + return nil, err + } + curOffset += int64(n) + + result = append(result, entry) + + curOffset = alignUp(curOffset, 8) + if _, err := sr.Seek(curOffset, io.SeekStart); err != nil { + if err == io.EOF { + break + } + return nil, err + } + } + + return result, nil +} diff --git a/vendor/github.com/dblohm7/wingoes/pe/pe_386.go b/vendor/github.com/dblohm7/wingoes/pe/pe_386.go new file mode 100644 index 0000000000..4332a3a92a --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/pe/pe_386.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package pe + +import ( + dpe "debug/pe" +) + +type optionalHeader dpe.OptionalHeader32 +type ptrOffset int32 + +const ( + expectedMachine = dpe.IMAGE_FILE_MACHINE_I386 + optionalHeaderMagic = 0x010B +) diff --git a/vendor/github.com/dblohm7/wingoes/pe/pe_amd64.go b/vendor/github.com/dblohm7/wingoes/pe/pe_amd64.go new file mode 100644 index 0000000000..ce462e578a --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/pe/pe_amd64.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package pe + +import ( + dpe "debug/pe" +) + +type optionalHeader dpe.OptionalHeader64 +type ptrOffset int64 + +const ( + expectedMachine = dpe.IMAGE_FILE_MACHINE_AMD64 + optionalHeaderMagic = 0x020B +) diff --git a/vendor/github.com/dblohm7/wingoes/pe/pe_arm64.go 
b/vendor/github.com/dblohm7/wingoes/pe/pe_arm64.go new file mode 100644 index 0000000000..c9c310e30e --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/pe/pe_arm64.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package pe + +import ( + dpe "debug/pe" +) + +type optionalHeader dpe.OptionalHeader64 +type ptrOffset int64 + +const ( + expectedMachine = dpe.IMAGE_FILE_MACHINE_ARM64 + optionalHeaderMagic = 0x020B +) diff --git a/vendor/github.com/dblohm7/wingoes/pe/version.go b/vendor/github.com/dblohm7/wingoes/pe/version.go new file mode 100644 index 0000000000..8e8a6fd285 --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/pe/version.go @@ -0,0 +1,168 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package pe + +import ( + "errors" + "fmt" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + errFixedFileInfoTooShort = errors.New("buffer smaller than VS_FIXEDFILEINFO") + errFixedFileInfoBadSig = errors.New("bad VS_FIXEDFILEINFO signature") +) + +// VersionNumber encapsulates a four-component version number that is stored +// in Windows VERSIONINFO resources. +type VersionNumber struct { + Major uint16 + Minor uint16 + Patch uint16 + Build uint16 +} + +func (vn VersionNumber) String() string { + return fmt.Sprintf("%d.%d.%d.%d", vn.Major, vn.Minor, vn.Patch, vn.Build) +} + +type langAndCodePage struct { + language uint16 + codePage uint16 +} + +// VersionInfo encapsulates a buffer containing the VERSIONINFO resources that +// have been successfully extracted from a PE binary. 
+type VersionInfo struct { + buf []byte + fixed *windows.VS_FIXEDFILEINFO + translationIDs []langAndCodePage +} + +const ( + langEnUS = 0x0409 + codePageUTF16LE = 0x04B0 + langNeutral = 0 + codePageNeutral = 0 +) + +// NewVersionInfo extracts any VERSIONINFO resource from filepath, parses its +// fixed-size information, and returns a *VersionInfo for further querying. +// It returns ErrNotPresent if no VERSIONINFO resources are found. +func NewVersionInfo(filepath string) (*VersionInfo, error) { + bufSize, err := windows.GetFileVersionInfoSize(filepath, nil) + if err != nil { + if errors.Is(err, windows.ERROR_RESOURCE_TYPE_NOT_FOUND) { + err = ErrNotPresent + } + return nil, err + } + + buf := make([]byte, bufSize) + if err := windows.GetFileVersionInfo(filepath, 0, bufSize, unsafe.Pointer(&buf[0])); err != nil { + return nil, err + } + + var fixed *windows.VS_FIXEDFILEINFO + var fixedLen uint32 + if err := windows.VerQueryValue(unsafe.Pointer(&buf[0]), `\`, unsafe.Pointer(&fixed), &fixedLen); err != nil { + return nil, err + } + if fixedLen < uint32(unsafe.Sizeof(windows.VS_FIXEDFILEINFO{})) { + return nil, errFixedFileInfoTooShort + } + if fixed.Signature != 0xFEEF04BD { + return nil, errFixedFileInfoBadSig + } + + return &VersionInfo{ + buf: buf, + fixed: fixed, + }, nil +} + +func (vi *VersionInfo) VersionNumber() VersionNumber { + f := vi.fixed + + return VersionNumber{ + Major: uint16(f.FileVersionMS >> 16), + Minor: uint16(f.FileVersionMS & 0xFFFF), + Patch: uint16(f.FileVersionLS >> 16), + Build: uint16(f.FileVersionLS & 0xFFFF), + } +} + +func (vi *VersionInfo) maybeLoadTranslationIDs() { + if vi.translationIDs != nil { + // Already loaded + return + } + + // Preferred translations, in order of preference. 
+ preferredTranslationIDs := []langAndCodePage{ + langAndCodePage{ + language: langEnUS, + codePage: codePageUTF16LE, + }, + langAndCodePage{ + language: langNeutral, + codePage: codePageNeutral, + }, + } + + var ids *langAndCodePage + var idsNumBytes uint32 + if err := windows.VerQueryValue( + unsafe.Pointer(&vi.buf[0]), + `\VarFileInfo\Translation`, + unsafe.Pointer(&ids), + &idsNumBytes, + ); err != nil { + // If nothing is listed, then just try to use our preferred translation IDs. + vi.translationIDs = preferredTranslationIDs + return + } + + idsSlice := unsafe.Slice(ids, idsNumBytes/uint32(unsafe.Sizeof(*ids))) + vi.translationIDs = append(preferredTranslationIDs, idsSlice...) +} + +func (vi *VersionInfo) queryWithLangAndCodePage(key string, lcp langAndCodePage) (string, error) { + fq := fmt.Sprintf("\\StringFileInfo\\%04x%04x\\%s", lcp.language, lcp.codePage, key) + + var value *uint16 + var valueLen uint32 + if err := windows.VerQueryValue(unsafe.Pointer(&vi.buf[0]), fq, unsafe.Pointer(&value), &valueLen); err != nil { + return "", err + } + + return windows.UTF16ToString(unsafe.Slice(value, valueLen)), nil +} + +// Field queries the version information for a field named key and either +// returns the field's value, or an error. It attempts to resolve strings using +// the following order of language preference: en-US, language-neutral, followed +// by the first entry in version info's list of supported languages that +// successfully resolves the key. +// If the key cannot be resolved, it returns ErrNotPresent. 
+func (vi *VersionInfo) Field(key string) (string, error) { + vi.maybeLoadTranslationIDs() + + for _, lcp := range vi.translationIDs { + value, err := vi.queryWithLangAndCodePage(key, lcp) + if err == nil { + return value, nil + } + if !errors.Is(err, windows.ERROR_RESOURCE_TYPE_NOT_FOUND) { + return "", err + } + // Otherwise we continue looping and try the next language + } + + return "", ErrNotPresent +} diff --git a/vendor/github.com/dblohm7/wingoes/time.go b/vendor/github.com/dblohm7/wingoes/time.go new file mode 100644 index 0000000000..e29a7f1b9e --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/time.go @@ -0,0 +1,29 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package wingoes + +import ( + "errors" + "time" + + "golang.org/x/sys/windows" +) + +var ( + // ErrDurationOutOfRange means that a time.Duration is too large to be able + // to be specified as a valid Win32 timeout value. + ErrDurationOutOfRange = errors.New("duration is out of timeout range") +) + +// DurationToTimeoutMilliseconds converts d into a timeout usable by Win32 APIs. +func DurationToTimeoutMilliseconds(d time.Duration) (uint32, error) { + millis := d.Milliseconds() + if millis >= windows.INFINITE { + return 0, ErrDurationOutOfRange + } + return uint32(millis), nil +} diff --git a/vendor/github.com/dblohm7/wingoes/util.go b/vendor/github.com/dblohm7/wingoes/util.go new file mode 100644 index 0000000000..464bd3f2bb --- /dev/null +++ b/vendor/github.com/dblohm7/wingoes/util.go @@ -0,0 +1,70 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build windows + +package wingoes + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +// UserSIDs contains pointers to the SIDs for a user and their primary group. +type UserSIDs struct { + User *windows.SID + PrimaryGroup *windows.SID +} + +// CurrentProcessUserSIDs returns a UserSIDs containing the SIDs of the user +// and primary group who own the current process. +func CurrentProcessUserSIDs() (*UserSIDs, error) { + token, err := windows.OpenCurrentProcessToken() + if err != nil { + return nil, err + } + defer token.Close() + + userInfo, err := getTokenInfo[windows.Tokenuser](token, windows.TokenUser) + if err != nil { + return nil, err + } + + primaryGroup, err := getTokenInfo[windows.Tokenprimarygroup](token, windows.TokenPrimaryGroup) + if err != nil { + return nil, err + } + + // We just want the SIDs, not the rest of the structs that were output. + userSid, err := userInfo.User.Sid.Copy() + if err != nil { + return nil, err + } + + primaryGroupSid, err := primaryGroup.PrimaryGroup.Copy() + if err != nil { + return nil, err + } + + return &UserSIDs{User: userSid, PrimaryGroup: primaryGroupSid}, nil +} + +func getTokenInfo[T any](token windows.Token, infoClass uint32) (*T, error) { + var buf []byte + var desiredLen uint32 + + err := windows.GetTokenInformation(token, infoClass, nil, 0, &desiredLen) + + for err != nil { + if err != windows.ERROR_INSUFFICIENT_BUFFER { + return nil, err + } + + buf = make([]byte, desiredLen) + err = windows.GetTokenInformation(token, infoClass, &buf[0], desiredLen, &desiredLen) + } + + return (*T)(unsafe.Pointer(&buf[0])), nil +} diff --git a/vendor/github.com/godbus/dbus/v5/.cirrus.yml b/vendor/github.com/godbus/dbus/v5/.cirrus.yml new file mode 100644 index 0000000000..4e900f86d9 --- /dev/null +++ b/vendor/github.com/godbus/dbus/v5/.cirrus.yml @@ -0,0 +1,10 @@ +freebsd_instance: + image_family: freebsd-13-0 + +task: + name: Test on FreeBSD + install_script: pkg install -y go119 dbus + test_script: | + 
/usr/local/etc/rc.d/dbus onestart && \ + eval `dbus-launch --sh-syntax` && \ + go119 test -v ./... diff --git a/vendor/github.com/godbus/dbus/v5/.golangci.yml b/vendor/github.com/godbus/dbus/v5/.golangci.yml new file mode 100644 index 0000000000..f2d7910d4c --- /dev/null +++ b/vendor/github.com/godbus/dbus/v5/.golangci.yml @@ -0,0 +1,7 @@ +# For documentation, see https://golangci-lint.run/usage/configuration/ + +linters: + enable: + - gofumpt + - unconvert + - unparam diff --git a/vendor/github.com/godbus/dbus/v5/README.md b/vendor/github.com/godbus/dbus/v5/README.md index 5c24125838..5c6b19655c 100644 --- a/vendor/github.com/godbus/dbus/v5/README.md +++ b/vendor/github.com/godbus/dbus/v5/README.md @@ -23,7 +23,7 @@ go get github.com/godbus/dbus/v5 ### Usage The complete package documentation and some simple examples are available at -[godoc.org](http://godoc.org/github.com/godbus/dbus). Also, the +[pkg.go.dev](https://pkg.go.dev/github.com/godbus/dbus/v5). Also, the [_examples](https://github.com/godbus/dbus/tree/master/_examples) directory gives a short overview over the basic usage. @@ -34,6 +34,7 @@ gives a short overview over the basic usage. - [iwd](https://github.com/shibumi/iwd) go bindings for the internet wireless daemon "iwd". - [notify](https://github.com/esiqveland/notify) provides desktop notifications over dbus into a library. - [playerbm](https://github.com/altdesktop/playerbm) a bookmark utility for media players. +- [rpic](https://github.com/stephenhu/rpic) lightweight web app and RESTful API for managing a Raspberry Pi Please note that the API is considered unstable for now and may change without further notice. 
diff --git a/vendor/github.com/godbus/dbus/v5/SECURITY.md b/vendor/github.com/godbus/dbus/v5/SECURITY.md new file mode 100644 index 0000000000..7d262fbbfc --- /dev/null +++ b/vendor/github.com/godbus/dbus/v5/SECURITY.md @@ -0,0 +1,13 @@ +# Security Policy + +## Supported Versions + +Security updates are applied only to the latest release. + +## Reporting a Vulnerability + +If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released. + +Please disclose it at [security advisory](https://github.com/godbus/dbus/security/advisories/new). + +This project is maintained by a team of volunteers on a reasonable-effort basis. As such, vulnerabilities will be disclosed in a best effort base. diff --git a/vendor/github.com/godbus/dbus/v5/auth.go b/vendor/github.com/godbus/dbus/v5/auth.go index 0f3b252c07..5fecbd3d41 100644 --- a/vendor/github.com/godbus/dbus/v5/auth.go +++ b/vendor/github.com/godbus/dbus/v5/auth.go @@ -83,9 +83,9 @@ func (conn *Conn) Auth(methods []Auth) error { } switch status { case AuthOk: - err, ok = conn.tryAuth(m, waitingForOk, in) + ok, err = conn.tryAuth(m, waitingForOk, in) case AuthContinue: - err, ok = conn.tryAuth(m, waitingForData, in) + ok, err = conn.tryAuth(m, waitingForData, in) default: panic("dbus: invalid authentication status") } @@ -125,21 +125,21 @@ func (conn *Conn) Auth(methods []Auth) error { } // tryAuth tries to authenticate with m as the mechanism, using state as the -// initial authState and in for reading input. It returns (nil, true) on -// success, (nil, false) on a REJECTED and (someErr, false) if some other +// initial authState and in for reading input. It returns (true, nil) on +// success, (false, nil) on a REJECTED and (false, someErr) if some other // error occurred. 
-func (conn *Conn) tryAuth(m Auth, state authState, in *bufio.Reader) (error, bool) { +func (conn *Conn) tryAuth(m Auth, state authState, in *bufio.Reader) (bool, error) { for { s, err := authReadLine(in) if err != nil { - return err, false + return false, err } switch { case state == waitingForData && string(s[0]) == "DATA": if len(s) != 2 { err = authWriteLine(conn.transport, []byte("ERROR")) if err != nil { - return err, false + return false, err } continue } @@ -149,7 +149,7 @@ func (conn *Conn) tryAuth(m Auth, state authState, in *bufio.Reader) (error, boo if len(data) != 0 { err = authWriteLine(conn.transport, []byte("DATA"), data) if err != nil { - return err, false + return false, err } } if status == AuthOk { @@ -158,66 +158,66 @@ func (conn *Conn) tryAuth(m Auth, state authState, in *bufio.Reader) (error, boo case AuthError: err = authWriteLine(conn.transport, []byte("ERROR")) if err != nil { - return err, false + return false, err } } case state == waitingForData && string(s[0]) == "REJECTED": - return nil, false + return false, nil case state == waitingForData && string(s[0]) == "ERROR": err = authWriteLine(conn.transport, []byte("CANCEL")) if err != nil { - return err, false + return false, err } state = waitingForReject case state == waitingForData && string(s[0]) == "OK": if len(s) != 2 { err = authWriteLine(conn.transport, []byte("CANCEL")) if err != nil { - return err, false + return false, err } state = waitingForReject } else { conn.uuid = string(s[1]) - return nil, true + return true, nil } case state == waitingForData: err = authWriteLine(conn.transport, []byte("ERROR")) if err != nil { - return err, false + return false, err } case state == waitingForOk && string(s[0]) == "OK": if len(s) != 2 { err = authWriteLine(conn.transport, []byte("CANCEL")) if err != nil { - return err, false + return false, err } state = waitingForReject } else { conn.uuid = string(s[1]) - return nil, true + return true, nil } case state == waitingForOk && string(s[0]) 
== "DATA": err = authWriteLine(conn.transport, []byte("DATA")) if err != nil { - return err, false + return false, nil } case state == waitingForOk && string(s[0]) == "REJECTED": - return nil, false + return false, nil case state == waitingForOk && string(s[0]) == "ERROR": err = authWriteLine(conn.transport, []byte("CANCEL")) if err != nil { - return err, false + return false, err } state = waitingForReject case state == waitingForOk: err = authWriteLine(conn.transport, []byte("ERROR")) if err != nil { - return err, false + return false, err } case state == waitingForReject && string(s[0]) == "REJECTED": - return nil, false + return false, nil case state == waitingForReject: - return errors.New("dbus: authentication protocol error"), false + return false, errors.New("dbus: authentication protocol error") default: panic("dbus: invalid auth state") } diff --git a/vendor/github.com/godbus/dbus/v5/call.go b/vendor/github.com/godbus/dbus/v5/call.go index b06b063580..ac01ec669f 100644 --- a/vendor/github.com/godbus/dbus/v5/call.go +++ b/vendor/github.com/godbus/dbus/v5/call.go @@ -2,11 +2,8 @@ package dbus import ( "context" - "errors" ) -var errSignature = errors.New("dbus: mismatched signature") - // Call represents a pending or completed method call. 
type Call struct { Destination string diff --git a/vendor/github.com/godbus/dbus/v5/conn.go b/vendor/github.com/godbus/dbus/v5/conn.go index 69978ea26a..bbe111b228 100644 --- a/vendor/github.com/godbus/dbus/v5/conn.go +++ b/vendor/github.com/godbus/dbus/v5/conn.go @@ -76,7 +76,6 @@ func SessionBus() (conn *Conn, err error) { func getSessionBusAddress(autolaunch bool) (string, error) { if address := os.Getenv("DBUS_SESSION_BUS_ADDRESS"); address != "" && address != "autolaunch:" { return address, nil - } else if address := tryDiscoverDbusSessionBusAddress(); address != "" { os.Setenv("DBUS_SESSION_BUS_ADDRESS", address) return address, nil @@ -485,7 +484,7 @@ func (conn *Conn) Object(dest string, path ObjectPath) BusObject { return &Object{conn, dest, path} } -func (conn *Conn) sendMessageAndIfClosed(msg *Message, ifClosed func()) { +func (conn *Conn) sendMessageAndIfClosed(msg *Message, ifClosed func()) error { if msg.serial == 0 { msg.serial = conn.getSerial() } @@ -498,6 +497,7 @@ func (conn *Conn) sendMessageAndIfClosed(msg *Message, ifClosed func()) { } else if msg.Type != TypeMethodCall { conn.serialGen.RetireSerial(msg.serial) } + return err } func (conn *Conn) handleSendError(msg *Message, err error) { @@ -505,6 +505,9 @@ func (conn *Conn) handleSendError(msg *Message, err error) { conn.calls.handleSendError(msg, err) } else if msg.Type == TypeMethodReply { if _, ok := err.(FormatError); ok { + // Make sure that the caller gets some kind of error response if + // the application code tried to respond, but the resulting message + // was malformed in the end conn.sendError(err, msg.Headers[FieldDestination].value.(string), msg.Headers[FieldReplySerial].value.(uint32)) } } @@ -560,7 +563,8 @@ func (conn *Conn) send(ctx context.Context, msg *Message, ch chan *Call) *Call { <-ctx.Done() conn.calls.handleSendError(msg, ctx.Err()) }() - conn.sendMessageAndIfClosed(msg, func() { + // error is handled in handleSendError + _ = conn.sendMessageAndIfClosed(msg, func() { 
conn.calls.handleSendError(msg, ErrClosed) canceler() }) @@ -568,7 +572,8 @@ func (conn *Conn) send(ctx context.Context, msg *Message, ch chan *Call) *Call { canceler() call = &Call{Err: nil, Done: ch} ch <- call - conn.sendMessageAndIfClosed(msg, func() { + // error is handled in handleSendError + _ = conn.sendMessageAndIfClosed(msg, func() { call = &Call{Err: ErrClosed} }) } @@ -602,7 +607,8 @@ func (conn *Conn) sendError(err error, dest string, serial uint32) { if len(e.Body) > 0 { msg.Headers[FieldSignature] = MakeVariant(SignatureOf(e.Body...)) } - conn.sendMessageAndIfClosed(msg, nil) + // not much we can do to handle a possible error here + _ = conn.sendMessageAndIfClosed(msg, nil) } // sendReply creates a method reply message corresponding to the parameters and @@ -619,7 +625,8 @@ func (conn *Conn) sendReply(dest string, serial uint32, values ...interface{}) { if len(values) > 0 { msg.Headers[FieldSignature] = MakeVariant(SignatureOf(values...)) } - conn.sendMessageAndIfClosed(msg, nil) + // not much we can do to handle a possible error here + _ = conn.sendMessageAndIfClosed(msg, nil) } // AddMatchSignal registers the given match rule to receive broadcast @@ -630,7 +637,7 @@ func (conn *Conn) AddMatchSignal(options ...MatchOption) error { // AddMatchSignalContext acts like AddMatchSignal but takes a context. func (conn *Conn) AddMatchSignalContext(ctx context.Context, options ...MatchOption) error { - options = append([]MatchOption{withMatchType("signal")}, options...) + options = append([]MatchOption{withMatchTypeSignal()}, options...) return conn.busObj.CallWithContext( ctx, "org.freedesktop.DBus.AddMatch", 0, @@ -645,7 +652,7 @@ func (conn *Conn) RemoveMatchSignal(options ...MatchOption) error { // RemoveMatchSignalContext acts like RemoveMatchSignal but takes a context. func (conn *Conn) RemoveMatchSignalContext(ctx context.Context, options ...MatchOption) error { - options = append([]MatchOption{withMatchType("signal")}, options...) 
+ options = append([]MatchOption{withMatchTypeSignal()}, options...) return conn.busObj.CallWithContext( ctx, "org.freedesktop.DBus.RemoveMatch", 0, @@ -740,9 +747,7 @@ type transport interface { SendMessage(*Message) error } -var ( - transports = make(map[string]func(string) (transport, error)) -) +var transports = make(map[string]func(string) (transport, error)) func getTransport(address string) (transport, error) { var err error @@ -853,16 +858,19 @@ type nameTracker struct { func newNameTracker() *nameTracker { return &nameTracker{names: map[string]struct{}{}} } + func (tracker *nameTracker) acquireUniqueConnectionName(name string) { tracker.lck.Lock() defer tracker.lck.Unlock() tracker.unique = name } + func (tracker *nameTracker) acquireName(name string) { tracker.lck.Lock() defer tracker.lck.Unlock() tracker.names[name] = struct{}{} } + func (tracker *nameTracker) loseName(name string) { tracker.lck.Lock() defer tracker.lck.Unlock() @@ -874,12 +882,14 @@ func (tracker *nameTracker) uniqueNameIsKnown() bool { defer tracker.lck.RUnlock() return tracker.unique != "" } + func (tracker *nameTracker) isKnownName(name string) bool { tracker.lck.RLock() defer tracker.lck.RUnlock() _, ok := tracker.names[name] return ok || name == tracker.unique } + func (tracker *nameTracker) listKnownNames() []string { tracker.lck.RLock() defer tracker.lck.RUnlock() @@ -941,17 +951,6 @@ func (tracker *callTracker) handleSendError(msg *Message, err error) { } } -// finalize was the only func that did not strobe Done -func (tracker *callTracker) finalize(sn uint32) { - tracker.lck.Lock() - defer tracker.lck.Unlock() - c, ok := tracker.calls[sn] - if ok { - delete(tracker.calls, sn) - c.ContextCancel() - } -} - func (tracker *callTracker) finalizeWithBody(sn uint32, sequence Sequence, body []interface{}) { tracker.lck.Lock() c, ok := tracker.calls[sn] diff --git a/vendor/github.com/godbus/dbus/v5/conn_darwin.go b/vendor/github.com/godbus/dbus/v5/conn_darwin.go index 
6e2e402021..cb2325a01b 100644 --- a/vendor/github.com/godbus/dbus/v5/conn_darwin.go +++ b/vendor/github.com/godbus/dbus/v5/conn_darwin.go @@ -12,7 +12,6 @@ const defaultSystemBusAddress = "unix:path=/opt/local/var/run/dbus/system_bus_so func getSessionBusPlatformAddress() (string, error) { cmd := exec.Command("launchctl", "getenv", "DBUS_LAUNCHD_SESSION_BUS_SOCKET") b, err := cmd.CombinedOutput() - if err != nil { return "", err } diff --git a/vendor/github.com/godbus/dbus/v5/conn_other.go b/vendor/github.com/godbus/dbus/v5/conn_other.go index 90289ca85a..067e67cc5b 100644 --- a/vendor/github.com/godbus/dbus/v5/conn_other.go +++ b/vendor/github.com/godbus/dbus/v5/conn_other.go @@ -1,3 +1,4 @@ +//go:build !darwin // +build !darwin package dbus @@ -19,7 +20,6 @@ var execCommand = exec.Command func getSessionBusPlatformAddress() (string, error) { cmd := execCommand("dbus-launch") b, err := cmd.CombinedOutput() - if err != nil { return "", err } @@ -42,10 +42,10 @@ func getSessionBusPlatformAddress() (string, error) { // It tries different techniques employed by different operating systems, // returning the first valid address it finds, or an empty string. // -// * /run/user//bus if this exists, it *is* the bus socket. present on -// Ubuntu 18.04 -// * /run/user//dbus-session: if this exists, it can be parsed for the bus -// address. present on Ubuntu 16.04 +// - /run/user//bus if this exists, it *is* the bus socket. present on +// Ubuntu 18.04 +// - /run/user//dbus-session: if this exists, it can be parsed for the bus +// address. 
present on Ubuntu 16.04 // // See https://dbus.freedesktop.org/doc/dbus-launch.1.html func tryDiscoverDbusSessionBusAddress() string { diff --git a/vendor/github.com/godbus/dbus/v5/conn_unix.go b/vendor/github.com/godbus/dbus/v5/conn_unix.go index 58aee7d2af..1a0daa6566 100644 --- a/vendor/github.com/godbus/dbus/v5/conn_unix.go +++ b/vendor/github.com/godbus/dbus/v5/conn_unix.go @@ -1,4 +1,5 @@ -//+build !windows,!solaris,!darwin +//go:build !windows && !solaris && !darwin +// +build !windows,!solaris,!darwin package dbus diff --git a/vendor/github.com/godbus/dbus/v5/conn_windows.go b/vendor/github.com/godbus/dbus/v5/conn_windows.go index 4291e4519c..fa839d2a22 100644 --- a/vendor/github.com/godbus/dbus/v5/conn_windows.go +++ b/vendor/github.com/godbus/dbus/v5/conn_windows.go @@ -1,5 +1,3 @@ -//+build windows - package dbus import "os" diff --git a/vendor/github.com/godbus/dbus/v5/dbus.go b/vendor/github.com/godbus/dbus/v5/dbus.go index c188d10485..8f152dc2f3 100644 --- a/vendor/github.com/godbus/dbus/v5/dbus.go +++ b/vendor/github.com/godbus/dbus/v5/dbus.go @@ -10,11 +10,8 @@ import ( var ( byteType = reflect.TypeOf(byte(0)) boolType = reflect.TypeOf(false) - uint8Type = reflect.TypeOf(uint8(0)) int16Type = reflect.TypeOf(int16(0)) uint16Type = reflect.TypeOf(uint16(0)) - intType = reflect.TypeOf(int(0)) - uintType = reflect.TypeOf(uint(0)) int32Type = reflect.TypeOf(int32(0)) uint32Type = reflect.TypeOf(uint32(0)) int64Type = reflect.TypeOf(int64(0)) @@ -85,7 +82,7 @@ func storeBase(dest, src reflect.Value) error { func setDest(dest, src reflect.Value) error { if !isVariant(src.Type()) && isVariant(dest.Type()) { - //special conversion for dbus.Variant + // special conversion for dbus.Variant dest.Set(reflect.ValueOf(MakeVariant(src.Interface()))) return nil } @@ -166,8 +163,8 @@ func storeMapIntoVariant(dest, src reflect.Value) error { func storeMapIntoInterface(dest, src reflect.Value) error { var dv reflect.Value if isVariant(src.Type().Elem()) { - //Convert 
variants to interface{} recursively when converting - //to interface{} + // Convert variants to interface{} recursively when converting + // to interface{} dv = reflect.MakeMap( reflect.MapOf(src.Type().Key(), interfaceType)) } else { @@ -200,7 +197,7 @@ func storeMapIntoMap(dest, src reflect.Value) error { func storeSlice(dest, src reflect.Value) error { switch { case src.Type() == interfacesType && dest.Kind() == reflect.Struct: - //The decoder always decodes structs as slices of interface{} + // The decoder always decodes structs as slices of interface{} return storeStruct(dest, src) case !kindsAreCompatible(dest.Type(), src.Type()): return fmt.Errorf( @@ -260,8 +257,8 @@ func storeSliceIntoVariant(dest, src reflect.Value) error { func storeSliceIntoInterface(dest, src reflect.Value) error { var dv reflect.Value if isVariant(src.Type().Elem()) { - //Convert variants to interface{} recursively when converting - //to interface{} + // Convert variants to interface{} recursively when converting + // to interface{} dv = reflect.MakeSlice(reflect.SliceOf(interfaceType), src.Len(), src.Cap()) } else { @@ -334,7 +331,7 @@ func (o ObjectPath) IsValid() bool { } // A UnixFD is a Unix file descriptor sent over the wire. See the package-level -// documentation for more information about Unix file descriptor passsing. +// documentation for more information about Unix file descriptor passing. type UnixFD int32 // A UnixFDIndex is the representation of a Unix file descriptor in a message. diff --git a/vendor/github.com/godbus/dbus/v5/decoder.go b/vendor/github.com/godbus/dbus/v5/decoder.go index 89bfed9d1a..97a827b83b 100644 --- a/vendor/github.com/godbus/dbus/v5/decoder.go +++ b/vendor/github.com/godbus/dbus/v5/decoder.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" "reflect" + "unsafe" ) type decoder struct { @@ -11,6 +12,12 @@ type decoder struct { order binary.ByteOrder pos int fds []int + + // The following fields are used to reduce memory allocs. 
+ conv *stringConverter + buf []byte + d float64 + y [1]byte } // newDecoder returns a new decoder that reads values from in. The input is @@ -20,17 +27,27 @@ func newDecoder(in io.Reader, order binary.ByteOrder, fds []int) *decoder { dec.in = in dec.order = order dec.fds = fds + dec.conv = newStringConverter(stringConverterBufferSize) return dec } +// Reset resets the decoder to be reading from in. +func (dec *decoder) Reset(in io.Reader, order binary.ByteOrder, fds []int) { + dec.in = in + dec.order = order + dec.pos = 0 + dec.fds = fds + + if dec.conv == nil { + dec.conv = newStringConverter(stringConverterBufferSize) + } +} + // align aligns the input to the given boundary and panics on error. func (dec *decoder) align(n int) { if dec.pos%n != 0 { newpos := (dec.pos + n - 1) & ^(n - 1) - empty := make([]byte, newpos-dec.pos) - if _, err := io.ReadFull(dec.in, empty); err != nil { - panic(err) - } + dec.read2buf(newpos - dec.pos) dec.pos = newpos } } @@ -66,79 +83,89 @@ func (dec *decoder) Decode(sig Signature) (vs []interface{}, err error) { return vs, nil } +// read2buf reads exactly n bytes from the reader dec.in into the buffer dec.buf +// to reduce memory allocs. +// The buffer grows automatically. +func (dec *decoder) read2buf(n int) { + if cap(dec.buf) < n { + dec.buf = make([]byte, n) + } else { + dec.buf = dec.buf[:n] + } + if _, err := io.ReadFull(dec.in, dec.buf); err != nil { + panic(err) + } +} + +// decodeU decodes uint32 obtained from the reader dec.in. +// The goal is to reduce memory allocs. 
+func (dec *decoder) decodeU() uint32 { + dec.align(4) + dec.read2buf(4) + dec.pos += 4 + return dec.order.Uint32(dec.buf) +} + func (dec *decoder) decode(s string, depth int) interface{} { dec.align(alignment(typeFor(s))) switch s[0] { case 'y': - var b [1]byte - if _, err := dec.in.Read(b[:]); err != nil { + if _, err := dec.in.Read(dec.y[:]); err != nil { panic(err) } dec.pos++ - return b[0] + return dec.y[0] case 'b': - i := dec.decode("u", depth).(uint32) - switch { - case i == 0: + switch dec.decodeU() { + case 0: return false - case i == 1: + case 1: return true default: panic(FormatError("invalid value for boolean")) } case 'n': - var i int16 - dec.binread(&i) + dec.read2buf(2) dec.pos += 2 - return i + return int16(dec.order.Uint16(dec.buf)) case 'i': - var i int32 - dec.binread(&i) + dec.read2buf(4) dec.pos += 4 - return i + return int32(dec.order.Uint32(dec.buf)) case 'x': - var i int64 - dec.binread(&i) + dec.read2buf(8) dec.pos += 8 - return i + return int64(dec.order.Uint64(dec.buf)) case 'q': - var i uint16 - dec.binread(&i) + dec.read2buf(2) dec.pos += 2 - return i + return dec.order.Uint16(dec.buf) case 'u': - var i uint32 - dec.binread(&i) - dec.pos += 4 - return i + return dec.decodeU() case 't': - var i uint64 - dec.binread(&i) + dec.read2buf(8) dec.pos += 8 - return i + return dec.order.Uint64(dec.buf) case 'd': - var f float64 - dec.binread(&f) + dec.binread(&dec.d) dec.pos += 8 - return f + return dec.d case 's': - length := dec.decode("u", depth).(uint32) - b := make([]byte, int(length)+1) - if _, err := io.ReadFull(dec.in, b); err != nil { - panic(err) - } - dec.pos += int(length) + 1 - return string(b[:len(b)-1]) + length := dec.decodeU() + p := int(length) + 1 + dec.read2buf(p) + dec.pos += p + return dec.conv.String(dec.buf[:len(dec.buf)-1]) case 'o': return ObjectPath(dec.decode("s", depth).(string)) case 'g': length := dec.decode("y", depth).(byte) - b := make([]byte, int(length)+1) - if _, err := io.ReadFull(dec.in, b); err != nil { - 
panic(err) - } - dec.pos += int(length) + 1 - sig, err := ParseSignature(string(b[:len(b)-1])) + p := int(length) + 1 + dec.read2buf(p) + dec.pos += p + sig, err := ParseSignature( + dec.conv.String(dec.buf[:len(dec.buf)-1]), + ) if err != nil { panic(err) } @@ -163,7 +190,7 @@ func (dec *decoder) decode(s string, depth int) interface{} { variant.value = dec.decode(sig.str, depth+1) return variant case 'h': - idx := dec.decode("u", depth).(uint32) + idx := dec.decodeU() if int(idx) < len(dec.fds) { return UnixFD(dec.fds[idx]) } @@ -176,7 +203,7 @@ func (dec *decoder) decode(s string, depth int) interface{} { if depth >= 63 { panic(FormatError("input exceeds container depth limit")) } - length := dec.decode("u", depth).(uint32) + length := dec.decodeU() // Even for empty maps, the correct padding must be included dec.align(8) spos := dec.pos @@ -195,7 +222,7 @@ func (dec *decoder) decode(s string, depth int) interface{} { panic(FormatError("input exceeds container depth limit")) } sig := s[1:] - length := dec.decode("u", depth).(uint32) + length := dec.decodeU() // capacity can be determined only for fixed-size element types var capacity int if s := sigByteSize(sig); s != 0 { @@ -205,9 +232,9 @@ func (dec *decoder) decode(s string, depth int) interface{} { // Even for empty arrays, the correct padding must be included align := alignment(typeFor(s[1:])) if len(s) > 1 && s[1] == '(' { - //Special case for arrays of structs - //structs decode as a slice of interface{} values - //but the dbus alignment does not match this + // Special case for arrays of structs + // structs decode as a slice of interface{} values + // but the dbus alignment does not match this align = 8 } dec.align(align) @@ -290,3 +317,65 @@ type FormatError string func (e FormatError) Error() string { return "dbus: wire format error: " + string(e) } + +// stringConverterBufferSize defines the recommended buffer size of 4KB. 
+// It showed good results in a benchmark when decoding 35KB message, +// see https://github.com/marselester/systemd#testing. +const stringConverterBufferSize = 4096 + +func newStringConverter(capacity int) *stringConverter { + return &stringConverter{ + buf: make([]byte, 0, capacity), + offset: 0, + } +} + +// stringConverter converts bytes to strings with less allocs. +// The idea is to accumulate bytes in a buffer with specified capacity +// and create strings with unsafe package using bytes from a buffer. +// For example, 10 "fizz" strings written to a 40-byte buffer +// will result in 1 alloc instead of 10. +// +// Once a buffer is filled, a new one is created with the same capacity. +// Old buffers will be eventually GC-ed +// with no side effects to the returned strings. +type stringConverter struct { + // buf is a temporary buffer where decoded strings are batched. + buf []byte + // offset is a buffer position where the last string was written. + offset int +} + +// String converts bytes to a string. +func (c *stringConverter) String(b []byte) string { + n := len(b) + if n == 0 { + return "" + } + // Must allocate because a string doesn't fit into the buffer. + if n > cap(c.buf) { + return string(b) + } + + if len(c.buf)+n > cap(c.buf) { + c.buf = make([]byte, 0, cap(c.buf)) + c.offset = 0 + } + c.buf = append(c.buf, b...) + + b = c.buf[c.offset:] + s := toString(b) + c.offset += n + return s +} + +// toString converts a byte slice to a string without allocating. +// Starting from Go 1.20 you should use unsafe.String. 
+func toString(b []byte) string { + var s string + h := (*reflect.StringHeader)(unsafe.Pointer(&s)) + h.Data = uintptr(unsafe.Pointer(&b[0])) + h.Len = len(b) + + return s +} diff --git a/vendor/github.com/godbus/dbus/v5/default_handler.go b/vendor/github.com/godbus/dbus/v5/default_handler.go index 13132c6b47..3da16b1ecb 100644 --- a/vendor/github.com/godbus/dbus/v5/default_handler.go +++ b/vendor/github.com/godbus/dbus/v5/default_handler.go @@ -18,9 +18,9 @@ func newIntrospectIntf(h *defaultHandler) *exportedIntf { return newExportedIntf(methods, true) } -//NewDefaultHandler returns an instance of the default -//call handler. This is useful if you want to implement only -//one of the two handlers but not both. +// NewDefaultHandler returns an instance of the default +// call handler. This is useful if you want to implement only +// one of the two handlers but not both. // // Deprecated: this is the default value, don't use it, it will be unexported. func NewDefaultHandler() *defaultHandler { @@ -148,7 +148,7 @@ func (m exportedMethod) Call(args ...interface{}) ([]interface{}, error) { out[i] = val.Interface() } if nilErr || err == nil { - //concrete type to interface nil is a special case + // concrete type to interface nil is a special case return out, nil } return out, err @@ -215,10 +215,6 @@ func (obj *exportedObj) LookupMethod(name string) (Method, bool) { return nil, false } -func (obj *exportedObj) isFallbackInterface() bool { - return false -} - func newExportedIntf(methods map[string]Method, includeSubtree bool) *exportedIntf { return &exportedIntf{ methods: methods, @@ -242,9 +238,9 @@ func (obj *exportedIntf) isFallbackInterface() bool { return obj.includeSubtree } -//NewDefaultSignalHandler returns an instance of the default -//signal handler. This is useful if you want to implement only -//one of the two handlers but not both. +// NewDefaultSignalHandler returns an instance of the default +// signal handler. 
This is useful if you want to implement only +// one of the two handlers but not both. // // Deprecated: this is the default value, don't use it, it will be unexported. func NewDefaultSignalHandler() *defaultSignalHandler { diff --git a/vendor/github.com/godbus/dbus/v5/doc.go b/vendor/github.com/godbus/dbus/v5/doc.go index 8f25a00d61..09eedc71e6 100644 --- a/vendor/github.com/godbus/dbus/v5/doc.go +++ b/vendor/github.com/godbus/dbus/v5/doc.go @@ -7,7 +7,7 @@ on remote objects and emit or receive signals. Using the Export method, you can arrange D-Bus methods calls to be directly translated to method calls on a Go value. -Conversion Rules +# Conversion Rules For outgoing messages, Go types are automatically converted to the corresponding D-Bus types. See the official specification at @@ -15,25 +15,25 @@ https://dbus.freedesktop.org/doc/dbus-specification.html#type-system for more information on the D-Bus type system. The following types are directly encoded as their respective D-Bus equivalents: - Go type | D-Bus type - ------------+----------- - byte | BYTE - bool | BOOLEAN - int16 | INT16 - uint16 | UINT16 - int | INT32 - uint | UINT32 - int32 | INT32 - uint32 | UINT32 - int64 | INT64 - uint64 | UINT64 - float64 | DOUBLE - string | STRING - ObjectPath | OBJECT_PATH - Signature | SIGNATURE - Variant | VARIANT - interface{} | VARIANT - UnixFDIndex | UNIX_FD + Go type | D-Bus type + ------------+----------- + byte | BYTE + bool | BOOLEAN + int16 | INT16 + uint16 | UINT16 + int | INT32 + uint | UINT32 + int32 | INT32 + uint32 | UINT32 + int64 | INT64 + uint64 | UINT64 + float64 | DOUBLE + string | STRING + ObjectPath | OBJECT_PATH + Signature | SIGNATURE + Variant | VARIANT + interface{} | VARIANT + UnixFDIndex | UNIX_FD Slices and arrays encode as ARRAYs of their element type. @@ -57,7 +57,7 @@ of STRUCTs. Incoming STRUCTS are represented as a slice of empty interfaces containing the struct fields in the correct order. 
The Store function can be used to convert such values to Go structs. -Unix FD passing +# Unix FD passing Handling Unix file descriptors deserves special mention. To use them, you should first check that they are supported on a connection by calling SupportsUnixFDs. @@ -66,6 +66,5 @@ UnixFD's to messages that are accompanied by the given file descriptors with the UnixFD values being substituted by the correct indices. Similarly, the indices of incoming messages are automatically resolved. It shouldn't be necessary to use UnixFDIndex. - */ package dbus diff --git a/vendor/github.com/godbus/dbus/v5/export.go b/vendor/github.com/godbus/dbus/v5/export.go index d3dd9f7cd6..fa009d2ce4 100644 --- a/vendor/github.com/godbus/dbus/v5/export.go +++ b/vendor/github.com/godbus/dbus/v5/export.go @@ -205,15 +205,13 @@ func (conn *Conn) handleCall(msg *Message) { } reply.Headers[FieldReplySerial] = MakeVariant(msg.serial) reply.Body = make([]interface{}, len(ret)) - for i := 0; i < len(ret); i++ { - reply.Body[i] = ret[i] - } + copy(reply.Body, ret) reply.Headers[FieldSignature] = MakeVariant(SignatureOf(reply.Body...)) - if err := reply.IsValid(); err != nil { - fmt.Fprintf(os.Stderr, "dbus: dropping invalid reply to %s.%s on obj %s: %s\n", ifaceName, name, path, err) - } else { - conn.sendMessageAndIfClosed(reply, nil) + if err := conn.sendMessageAndIfClosed(reply, nil); err != nil { + if _, ok := err.(FormatError); ok { + fmt.Fprintf(os.Stderr, "dbus: replacing invalid reply to %s.%s on obj %s: %s\n", ifaceName, name, path, err) + } } } } @@ -237,18 +235,15 @@ func (conn *Conn) Emit(path ObjectPath, name string, values ...interface{}) erro if len(values) > 0 { msg.Headers[FieldSignature] = MakeVariant(SignatureOf(values...)) } - if err := msg.IsValid(); err != nil { - return err - } var closed bool - conn.sendMessageAndIfClosed(msg, func() { + err := conn.sendMessageAndIfClosed(msg, func() { closed = true }) if closed { return ErrClosed } - return nil + return err } // Export 
registers the given value to be exported as an object on the diff --git a/vendor/github.com/godbus/dbus/v5/match.go b/vendor/github.com/godbus/dbus/v5/match.go index 5a607e53e4..ffb0134475 100644 --- a/vendor/github.com/godbus/dbus/v5/match.go +++ b/vendor/github.com/godbus/dbus/v5/match.go @@ -26,10 +26,10 @@ func WithMatchOption(key, value string) MatchOption { return MatchOption{key, value} } -// doesn't make sense to export this option because clients can only -// subscribe to messages with signal type. -func withMatchType(typ string) MatchOption { - return WithMatchOption("type", typ) +// It does not make sense to have a public WithMatchType function +// because clients can only subscribe to messages with signal type. +func withMatchTypeSignal() MatchOption { + return WithMatchOption("type", "signal") } // WithMatchSender sets sender match option. diff --git a/vendor/github.com/godbus/dbus/v5/message.go b/vendor/github.com/godbus/dbus/v5/message.go index bdf43fdd6e..5ab6e9d9a1 100644 --- a/vendor/github.com/godbus/dbus/v5/message.go +++ b/vendor/github.com/godbus/dbus/v5/message.go @@ -158,7 +158,9 @@ func DecodeMessageWithFDs(rd io.Reader, fds []int) (msg *Message, err error) { if err != nil { return nil, err } - binary.Read(bytes.NewBuffer(b), order, &hlength) + if err := binary.Read(bytes.NewBuffer(b), order, &hlength); err != nil { + return nil, err + } if hlength+length+16 > 1<<27 { return nil, InvalidMessageError("message is too long") } @@ -186,7 +188,7 @@ func DecodeMessageWithFDs(rd io.Reader, fds []int) (msg *Message, err error) { } } - if err = msg.IsValid(); err != nil { + if err = msg.validateHeader(); err != nil { return nil, err } sig, _ := msg.Headers[FieldSignature].value.(Signature) @@ -265,12 +267,14 @@ func (msg *Message) EncodeToWithFDs(out io.Writer, order binary.ByteOrder) (fds return } enc.align(8) - body.WriteTo(&buf) + if _, err := body.WriteTo(&buf); err != nil { + return nil, err + } if buf.Len() > 1<<27 { - return make([]int, 0), 
InvalidMessageError("message is too long") + return nil, InvalidMessageError("message is too long") } if _, err := buf.WriteTo(out); err != nil { - return make([]int, 0), err + return nil, err } return enc.fds, nil } @@ -286,8 +290,7 @@ func (msg *Message) EncodeTo(out io.Writer, order binary.ByteOrder) (err error) // IsValid checks whether msg is a valid message and returns an // InvalidMessageError or FormatError if it is not. func (msg *Message) IsValid() error { - var b bytes.Buffer - return msg.EncodeTo(&b, nativeEndian) + return msg.EncodeTo(io.Discard, nativeEndian) } func (msg *Message) validateHeader() error { diff --git a/vendor/github.com/godbus/dbus/v5/object.go b/vendor/github.com/godbus/dbus/v5/object.go index 664abb7fba..b4b1e939b3 100644 --- a/vendor/github.com/godbus/dbus/v5/object.go +++ b/vendor/github.com/godbus/dbus/v5/object.go @@ -46,7 +46,7 @@ func (o *Object) CallWithContext(ctx context.Context, method string, flags Flags // Deprecated: use (*Conn) AddMatchSignal instead. func (o *Object) AddMatchSignal(iface, member string, options ...MatchOption) *Call { base := []MatchOption{ - withMatchType("signal"), + withMatchTypeSignal(), WithMatchInterface(iface), WithMatchMember(member), } @@ -65,7 +65,7 @@ func (o *Object) AddMatchSignal(iface, member string, options ...MatchOption) *C // Deprecated: use (*Conn) RemoveMatchSignal instead. func (o *Object) RemoveMatchSignal(iface, member string, options ...MatchOption) *Call { base := []MatchOption{ - withMatchType("signal"), + withMatchTypeSignal(), WithMatchInterface(iface), WithMatchMember(member), } @@ -151,7 +151,14 @@ func (o *Object) StoreProperty(p string, value interface{}) error { // SetProperty calls org.freedesktop.DBus.Properties.Set on the given // object. The property name must be given in interface.member notation. +// Panics if v is not a valid Variant type. func (o *Object) SetProperty(p string, v interface{}) error { + // v might already be a variant... 
+ variant, ok := v.(Variant) + if !ok { + // Otherwise, make it into one. + variant = MakeVariant(v) + } idx := strings.LastIndex(p, ".") if idx == -1 || idx+1 == len(p) { return errors.New("dbus: invalid property " + p) @@ -160,7 +167,7 @@ func (o *Object) SetProperty(p string, v interface{}) error { iface := p[:idx] prop := p[idx+1:] - return o.Call("org.freedesktop.DBus.Properties.Set", 0, iface, prop, v).Err + return o.Call("org.freedesktop.DBus.Properties.Set", 0, iface, prop, variant).Err } // Destination returns the destination that calls on (o *Object) are sent to. diff --git a/vendor/github.com/godbus/dbus/v5/sequential_handler.go b/vendor/github.com/godbus/dbus/v5/sequential_handler.go index ef2fcdba17..886b5eb16b 100644 --- a/vendor/github.com/godbus/dbus/v5/sequential_handler.go +++ b/vendor/github.com/godbus/dbus/v5/sequential_handler.go @@ -93,7 +93,7 @@ func (scd *sequentialSignalChannelData) bufferSignals() { var queue []*Signal for { if len(queue) == 0 { - signal, ok := <- scd.in + signal, ok := <-scd.in if !ok { return } diff --git a/vendor/github.com/godbus/dbus/v5/server_interfaces.go b/vendor/github.com/godbus/dbus/v5/server_interfaces.go index e4e0389fdf..aef3c772fa 100644 --- a/vendor/github.com/godbus/dbus/v5/server_interfaces.go +++ b/vendor/github.com/godbus/dbus/v5/server_interfaces.go @@ -22,7 +22,7 @@ type Handler interface { // of Interface lookup is up to the implementation of // the ServerObject. The ServerObject implementation may // choose to implement empty string as a valid interface -// represeting all methods or not per the D-Bus specification. +// representing all methods or not per the D-Bus specification. 
type ServerObject interface { LookupInterface(name string) (Interface, bool) } diff --git a/vendor/github.com/godbus/dbus/v5/sig.go b/vendor/github.com/godbus/dbus/v5/sig.go index 6b9cadb5fb..5bd797b62f 100644 --- a/vendor/github.com/godbus/dbus/v5/sig.go +++ b/vendor/github.com/godbus/dbus/v5/sig.go @@ -183,19 +183,19 @@ func (cnt *depthCounter) Valid() bool { return cnt.arrayDepth <= 32 && cnt.structDepth <= 32 && cnt.dictEntryDepth <= 32 } -func (cnt depthCounter) EnterArray() *depthCounter { +func (cnt *depthCounter) EnterArray() *depthCounter { cnt.arrayDepth++ - return &cnt + return cnt } -func (cnt depthCounter) EnterStruct() *depthCounter { +func (cnt *depthCounter) EnterStruct() *depthCounter { cnt.structDepth++ - return &cnt + return cnt } -func (cnt depthCounter) EnterDictEntry() *depthCounter { +func (cnt *depthCounter) EnterDictEntry() *depthCounter { cnt.dictEntryDepth++ - return &cnt + return cnt } // Try to read a single type from this string. If it was successful, err is nil @@ -221,6 +221,9 @@ func validSingle(s string, depth *depthCounter) (err error, rem string) { i++ rem = s[i+1:] s = s[2:i] + if len(s) == 0 { + return SignatureError{Sig: s, Reason: "empty dict"}, "" + } if err, _ = validSingle(s[:1], depth.EnterArray().EnterDictEntry()); err != nil { return err, "" } diff --git a/vendor/github.com/godbus/dbus/v5/transport_nonce_tcp.go b/vendor/github.com/godbus/dbus/v5/transport_nonce_tcp.go index 697739efaf..a61a82084b 100644 --- a/vendor/github.com/godbus/dbus/v5/transport_nonce_tcp.go +++ b/vendor/github.com/godbus/dbus/v5/transport_nonce_tcp.go @@ -1,4 +1,5 @@ -//+build !windows +//go:build !windows +// +build !windows package dbus diff --git a/vendor/github.com/godbus/dbus/v5/transport_unix.go b/vendor/github.com/godbus/dbus/v5/transport_unix.go index 0a8c712ebd..6840387a5f 100644 --- a/vendor/github.com/godbus/dbus/v5/transport_unix.go +++ b/vendor/github.com/godbus/dbus/v5/transport_unix.go @@ -1,4 +1,5 @@ -//+build !windows,!solaris 
+//go:build !windows && !solaris +// +build !windows,!solaris package dbus @@ -11,10 +12,29 @@ import ( "syscall" ) +// msghead represents the part of the message header +// that has a constant size (byte order + 15 bytes). +type msghead struct { + Type Type + Flags Flags + Proto byte + BodyLen uint32 + Serial uint32 + HeaderLen uint32 +} + type oobReader struct { conn *net.UnixConn oob []byte buf [4096]byte + + // The following fields are used to reduce memory allocs. + headers []header + csheader []byte + b *bytes.Buffer + r *bytes.Reader + dec *decoder + msghead } func (o *oobReader) Read(b []byte) (n int, err error) { @@ -70,28 +90,36 @@ func (t *unixTransport) EnableUnixFDs() { } func (t *unixTransport) ReadMessage() (*Message, error) { - var ( - blen, hlen uint32 - csheader [16]byte - headers []header - order binary.ByteOrder - unixfds uint32 - ) // To be sure that all bytes of out-of-band data are read, we use a special // reader that uses ReadUnix on the underlying connection instead of Read // and gathers the out-of-band data in a buffer. if t.rdr == nil { - t.rdr = &oobReader{conn: t.UnixConn} + t.rdr = &oobReader{ + conn: t.UnixConn, + // This buffer is used to decode the part of the header that has a constant size. + csheader: make([]byte, 16), + b: &bytes.Buffer{}, + // The reader helps to read from the buffer several times. 
+ r: &bytes.Reader{}, + dec: &decoder{}, + } } else { - t.rdr.oob = nil + t.rdr.oob = t.rdr.oob[:0] + t.rdr.headers = t.rdr.headers[:0] } + var ( + r = t.rdr.r + b = t.rdr.b + dec = t.rdr.dec + ) - // read the first 16 bytes (the part of the header that has a constant size), - // from which we can figure out the length of the rest of the message - if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil { + _, err := io.ReadFull(t.rdr, t.rdr.csheader) + if err != nil { return nil, err } - switch csheader[0] { + + var order binary.ByteOrder + switch t.rdr.csheader[0] { case 'l': order = binary.LittleEndian case 'B': @@ -99,38 +127,62 @@ func (t *unixTransport) ReadMessage() (*Message, error) { default: return nil, InvalidMessageError("invalid byte order") } - // csheader[4:8] -> length of message body, csheader[12:16] -> length of - // header fields (without alignment) - binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen) - binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen) + + r.Reset(t.rdr.csheader[1:]) + if err := binary.Read(r, order, &t.rdr.msghead); err != nil { + return nil, err + } + + msg := &Message{ + Type: t.rdr.msghead.Type, + Flags: t.rdr.msghead.Flags, + serial: t.rdr.msghead.Serial, + } + // Length of header fields (without alignment). + hlen := t.rdr.msghead.HeaderLen if hlen%8 != 0 { hlen += 8 - (hlen % 8) } + if hlen+t.rdr.msghead.BodyLen+16 > 1<<27 { + return nil, InvalidMessageError("message is too long") + } - // decode headers and look for unix fds - headerdata := make([]byte, hlen+4) - copy(headerdata, csheader[12:]) - if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil { + // Decode headers and look for unix fds. 
+ b.Reset() + if _, err = b.Write(t.rdr.csheader[12:]); err != nil { + return nil, err + } + if _, err = io.CopyN(b, t.rdr, int64(hlen)); err != nil { return nil, err } - dec := newDecoder(bytes.NewBuffer(headerdata), order, make([]int, 0)) + dec.Reset(b, order, nil) dec.pos = 12 vs, err := dec.Decode(Signature{"a(yv)"}) if err != nil { return nil, err } - Store(vs, &headers) - for _, v := range headers { + if err = Store(vs, &t.rdr.headers); err != nil { + return nil, err + } + var unixfds uint32 + for _, v := range t.rdr.headers { if v.Field == byte(FieldUnixFDs) { unixfds, _ = v.Variant.value.(uint32) } } - all := make([]byte, 16+hlen+blen) - copy(all, csheader[:]) - copy(all[16:], headerdata[4:]) - if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil { + + msg.Headers = make(map[HeaderField]Variant) + for _, v := range t.rdr.headers { + msg.Headers[HeaderField(v.Field)] = v.Variant + } + + dec.align(8) + body := make([]byte, t.rdr.BodyLen) + if _, err = io.ReadFull(t.rdr, body); err != nil { return nil, err } + r.Reset(body) + if unixfds != 0 { if !t.hasUnixFDs { return nil, errors.New("dbus: got unix fds on unsupported transport") @@ -147,8 +199,8 @@ func (t *unixTransport) ReadMessage() (*Message, error) { if err != nil { return nil, err } - msg, err := DecodeMessageWithFDs(bytes.NewBuffer(all), fds) - if err != nil { + dec.Reset(r, order, fds) + if err = decodeMessageBody(msg, dec); err != nil { return nil, err } // substitute the values in the message body (which are indices for the @@ -173,7 +225,27 @@ func (t *unixTransport) ReadMessage() (*Message, error) { } return msg, nil } - return DecodeMessage(bytes.NewBuffer(all)) + + dec.Reset(r, order, nil) + if err = decodeMessageBody(msg, dec); err != nil { + return nil, err + } + return msg, nil +} + +func decodeMessageBody(msg *Message, dec *decoder) error { + if err := msg.validateHeader(); err != nil { + return err + } + + sig, _ := msg.Headers[FieldSignature].value.(Signature) + if sig.str == "" { + 
return nil + } + + var err error + msg.Body, err = dec.Decode(sig) + return err } func (t *unixTransport) SendMessage(msg *Message) error { diff --git a/vendor/github.com/godbus/dbus/v5/transport_unixcred_freebsd.go b/vendor/github.com/godbus/dbus/v5/transport_unixcred_freebsd.go index 1b5ed2089d..ff2284c838 100644 --- a/vendor/github.com/godbus/dbus/v5/transport_unixcred_freebsd.go +++ b/vendor/github.com/godbus/dbus/v5/transport_unixcred_freebsd.go @@ -7,39 +7,41 @@ package dbus -/* -const int sizeofPtr = sizeof(void*); -#define _WANT_UCRED -#include -#include -*/ -import "C" - import ( "io" "os" "syscall" "unsafe" + + "golang.org/x/sys/unix" ) // http://golang.org/src/pkg/syscall/ztypes_linux_amd64.go // https://golang.org/src/syscall/ztypes_freebsd_amd64.go +// +// Note: FreeBSD actually uses a 'struct cmsgcred' which starts with +// these fields and adds a list of the additional groups for the +// sender. type Ucred struct { - Pid int32 - Uid uint32 - Gid uint32 + Pid int32 + Uid uint32 + Euid uint32 + Gid uint32 } -// http://golang.org/src/pkg/syscall/types_linux.go -// https://golang.org/src/syscall/types_freebsd.go -// https://github.com/freebsd/freebsd/blob/master/sys/sys/ucred.h +// https://github.com/freebsd/freebsd/blob/master/sys/sys/socket.h +// +// The cmsgcred structure contains the above four fields, followed by +// a uint16 count of additional groups, uint16 padding to align and a +// 16 element array of uint32 for the additional groups. The size is +// the same across all supported platforms. const ( - SizeofUcred = C.sizeof_struct_ucred + SizeofCmsgcred = 84 // 4*4 + 2*2 + 16*4 ) // http://golang.org/src/pkg/syscall/sockcmsg_unix.go func cmsgAlignOf(salen int) int { - salign := C.sizeofPtr + salign := unix.SizeofPtr return (salen + salign - 1) & ^(salign - 1) } @@ -54,11 +56,11 @@ func cmsgData(h *syscall.Cmsghdr) unsafe.Pointer { // for sending to another process. This can be used for // authentication. 
func UnixCredentials(ucred *Ucred) []byte { - b := make([]byte, syscall.CmsgSpace(SizeofUcred)) + b := make([]byte, syscall.CmsgSpace(SizeofCmsgcred)) h := (*syscall.Cmsghdr)(unsafe.Pointer(&b[0])) h.Level = syscall.SOL_SOCKET h.Type = syscall.SCM_CREDS - h.SetLen(syscall.CmsgLen(SizeofUcred)) + h.SetLen(syscall.CmsgLen(SizeofCmsgcred)) *((*Ucred)(cmsgData(h))) = *ucred return b } diff --git a/vendor/github.com/godbus/dbus/v5/variant_parser.go b/vendor/github.com/godbus/dbus/v5/variant_parser.go index d20f5da6dd..9532e36f3c 100644 --- a/vendor/github.com/godbus/dbus/v5/variant_parser.go +++ b/vendor/github.com/godbus/dbus/v5/variant_parser.go @@ -417,7 +417,6 @@ func (b boolNode) Value(sig Signature) (interface{}, error) { type arrayNode struct { set sigSet children []varNode - val interface{} } func (n arrayNode) Infer() (Signature, error) { @@ -574,7 +573,6 @@ type dictEntry struct { type dictNode struct { kset, vset sigSet children []dictEntry - val interface{} } func (n dictNode) Infer() (Signature, error) { diff --git a/vendor/github.com/google/nftables/CONTRIBUTING.md b/vendor/github.com/google/nftables/CONTRIBUTING.md new file mode 100644 index 0000000000..ae319c70ac --- /dev/null +++ b/vendor/github.com/google/nftables/CONTRIBUTING.md @@ -0,0 +1,23 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. 
+ +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. diff --git a/vendor/github.com/google/nftables/LICENSE b/vendor/github.com/google/nftables/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/vendor/github.com/google/nftables/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/google/nftables/README.md b/vendor/github.com/google/nftables/README.md new file mode 100644 index 0000000000..cb633c7186 --- /dev/null +++ b/vendor/github.com/google/nftables/README.md @@ -0,0 +1,24 @@ +[![Build Status](https://github.com/google/nftables/actions/workflows/push.yml/badge.svg)](https://github.com/google/nftables/actions/workflows/push.yml) +[![GoDoc](https://godoc.org/github.com/google/nftables?status.svg)](https://godoc.org/github.com/google/nftables) + +**This is not the correct repository for issues with the Linux nftables +project!** This repository contains a third-party Go package to programmatically +interact with nftables. Find the official nftables website at +https://wiki.nftables.org/ + +This package manipulates Linux nftables (the iptables successor). It is +implemented in pure Go, i.e. does not wrap libnftnl. + +This is not an official Google product. + +## Breaking changes + +This package is in very early stages, and only contains enough data types and +functions to install very basic nftables rules. It is likely that mistakes with +the data types/API will be identified as more functionality is added. + +## Contributions + +Contributions are very welcome! 
+ + diff --git a/vendor/github.com/google/nftables/alignedbuff/alignedbuff.go b/vendor/github.com/google/nftables/alignedbuff/alignedbuff.go new file mode 100644 index 0000000000..a97214649d --- /dev/null +++ b/vendor/github.com/google/nftables/alignedbuff/alignedbuff.go @@ -0,0 +1,300 @@ +// Package alignedbuff implements encoding and decoding aligned data elements +// to/from buffers in native endianess. +// +// # Note +// +// The alignment/padding as implemented in this package must match that of +// kernel's and user space C implementations for a particular architecture (bit +// size). Please see also the "dummy structure" _xt_align +// (https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/x_tables.h#L93) +// as well as the associated XT_ALIGN C preprocessor macro. +// +// In particular, we rely on the Go compiler to follow the same architecture +// alignments as the C compiler(s) on Linux. +package alignedbuff + +import ( + "bytes" + "errors" + "fmt" + "unsafe" + + "github.com/google/nftables/binaryutil" +) + +// ErrEOF signals trying to read beyond the available payload information. +var ErrEOF = errors.New("not enough data left") + +// AlignedBuff implements marshalling and unmarshalling information in +// platform/architecture native endianess and data type alignment. It +// additionally covers some of the nftables-xtables translation-specific +// idiosyncracies to the extend needed in order to properly marshal and +// unmarshal Match and Target expressions, and their Info payload in particular. +type AlignedBuff struct { + data []byte + pos int +} + +// New returns a new AlignedBuff for marshalling aligned data in native +// endianess. +func New() AlignedBuff { + return AlignedBuff{} +} + +// NewWithData returns a new AlignedBuff for unmarshalling the passed data in +// native endianess. 
+func NewWithData(data []byte) AlignedBuff { + return AlignedBuff{data: data} +} + +// Data returns the properly padded info payload data written before by calling +// the various Uint8, Uint16, ... marshalling functions. +func (a *AlignedBuff) Data() []byte { + // The Linux kernel expects payloads to be padded to the next uint64 + // alignment. + a.alignWrite(uint64AlignMask) + return a.data +} + +// BytesAligned32 unmarshals the given amount of bytes starting with the native +// alignment for uint32 data types. It returns ErrEOF when trying to read beyond +// the payload. +// +// BytesAligned32 is used to unmarshal IP addresses for different IP versions, +// which are always aligned the same way as the native alignment for uint32. +func (a *AlignedBuff) BytesAligned32(size int) ([]byte, error) { + if err := a.alignCheckedRead(uint32AlignMask); err != nil { + return nil, err + } + if a.pos > len(a.data)-size { + return nil, ErrEOF + } + data := a.data[a.pos : a.pos+size] + a.pos += size + return data, nil +} + +// Uint8 unmarshals an uint8 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint8() (uint8, error) { + if a.pos >= len(a.data) { + return 0, ErrEOF + } + v := a.data[a.pos] + a.pos++ + return v, nil +} + +// Uint16 unmarshals an uint16 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint16() (uint16, error) { + if err := a.alignCheckedRead(uint16AlignMask); err != nil { + return 0, err + } + v := binaryutil.NativeEndian.Uint16(a.data[a.pos : a.pos+2]) + a.pos += 2 + return v, nil +} + +// Uint16BE unmarshals an uint16 in "network" (=big endian) endianess and native +// uint16 alignment. It returns ErrEOF when trying to read beyond the payload. 
+func (a *AlignedBuff) Uint16BE() (uint16, error) { + if err := a.alignCheckedRead(uint16AlignMask); err != nil { + return 0, err + } + v := binaryutil.BigEndian.Uint16(a.data[a.pos : a.pos+2]) + a.pos += 2 + return v, nil +} + +// Uint32 unmarshals an uint32 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint32() (uint32, error) { + if err := a.alignCheckedRead(uint32AlignMask); err != nil { + return 0, err + } + v := binaryutil.NativeEndian.Uint32(a.data[a.pos : a.pos+4]) + a.pos += 4 + return v, nil +} + +// Uint64 unmarshals an uint64 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint64() (uint64, error) { + if err := a.alignCheckedRead(uint64AlignMask); err != nil { + return 0, err + } + v := binaryutil.NativeEndian.Uint64(a.data[a.pos : a.pos+8]) + a.pos += 8 + return v, nil +} + +// Int32 unmarshals an int32 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Int32() (int32, error) { + if err := a.alignCheckedRead(int32AlignMask); err != nil { + return 0, err + } + v := binaryutil.Int32(a.data[a.pos : a.pos+4]) + a.pos += 4 + return v, nil +} + +// String unmarshals a null terminated string +func (a *AlignedBuff) String() (string, error) { + len := 0 + for { + if a.data[a.pos+len] == 0x00 { + break + } + len++ + } + + v := binaryutil.String(a.data[a.pos : a.pos+len]) + a.pos += len + return v, nil +} + +// StringWithLength unmarshals a string of a given length (for non-null +// terminated strings) +func (a *AlignedBuff) StringWithLength(len int) (string, error) { + v := binaryutil.String(a.data[a.pos : a.pos+len]) + a.pos += len + return v, nil +} + +// Uint unmarshals an uint in native endianess and alignment for the C "unsigned +// int" type. It returns ErrEOF when trying to read beyond the payload. 
Please +// note that on 64bit platforms, the size and alignment of C's and Go's unsigned +// integer data types differ, so we encapsulate this difference here. +func (a *AlignedBuff) Uint() (uint, error) { + switch uintSize { + case 2: + v, err := a.Uint16() + return uint(v), err + case 4: + v, err := a.Uint32() + return uint(v), err + case 8: + v, err := a.Uint64() + return uint(v), err + default: + panic(fmt.Sprintf("unsupported uint size %d", uintSize)) + } +} + +// PutBytesAligned32 marshals the given bytes starting with the native alignment +// for uint32 data types. It additionaly adds padding to reach the specified +// size. +// +// PutBytesAligned32 is used to marshal IP addresses for different IP versions, +// which are always aligned the same way as the native alignment for uint32. +func (a *AlignedBuff) PutBytesAligned32(data []byte, size int) { + a.alignWrite(uint32AlignMask) + a.data = append(a.data, data...) + a.pos += len(data) + if len(data) < size { + padding := size - len(data) + a.data = append(a.data, bytes.Repeat([]byte{0}, padding)...) + a.pos += padding + } +} + +// PutUint8 marshals an uint8 in native endianess and alignment. +func (a *AlignedBuff) PutUint8(v uint8) { + a.data = append(a.data, v) + a.pos++ +} + +// PutUint16 marshals an uint16 in native endianess and alignment. +func (a *AlignedBuff) PutUint16(v uint16) { + a.alignWrite(uint16AlignMask) + a.data = append(a.data, binaryutil.NativeEndian.PutUint16(v)...) + a.pos += 2 +} + +// PutUint16BE marshals an uint16 in "network" (=big endian) endianess and +// native uint16 alignment. +func (a *AlignedBuff) PutUint16BE(v uint16) { + a.alignWrite(uint16AlignMask) + a.data = append(a.data, binaryutil.BigEndian.PutUint16(v)...) + a.pos += 2 +} + +// PutUint32 marshals an uint32 in native endianess and alignment. +func (a *AlignedBuff) PutUint32(v uint32) { + a.alignWrite(uint32AlignMask) + a.data = append(a.data, binaryutil.NativeEndian.PutUint32(v)...) 
+ a.pos += 4 +} + +// PutUint64 marshals an uint64 in native endianess and alignment. +func (a *AlignedBuff) PutUint64(v uint64) { + a.alignWrite(uint64AlignMask) + a.data = append(a.data, binaryutil.NativeEndian.PutUint64(v)...) + a.pos += 8 +} + +// PutInt32 marshals an int32 in native endianess and alignment. +func (a *AlignedBuff) PutInt32(v int32) { + a.alignWrite(int32AlignMask) + a.data = append(a.data, binaryutil.PutInt32(v)...) + a.pos += 4 +} + +// PutString marshals a string. +func (a *AlignedBuff) PutString(v string) { + a.data = append(a.data, binaryutil.PutString(v)...) + a.pos += len(v) +} + +// PutUint marshals an uint in native endianess and alignment for the C +// "unsigned int" type. Please note that on 64bit platforms, the size and +// alignment of C's and Go's unsigned integer data types differ, so we +// encapsulate this difference here. +func (a *AlignedBuff) PutUint(v uint) { + switch uintSize { + case 2: + a.PutUint16(uint16(v)) + case 4: + a.PutUint32(uint32(v)) + case 8: + a.PutUint64(uint64(v)) + default: + panic(fmt.Sprintf("unsupported uint size %d", uintSize)) + } +} + +// alignCheckedRead aligns the (read) position if necessary and suitable +// according to the specified alignment mask. alignCheckedRead returns an error +// if after any necessary alignment there isn't enough data left to be read into +// a value of the size corresponding to the specified alignment mask. +func (a *AlignedBuff) alignCheckedRead(m int) error { + a.pos = (a.pos + m) & ^m + if a.pos > len(a.data)-(m+1) { + return ErrEOF + } + return nil +} + +// alignWrite aligns the (write) position if necessary and suitable according to +// the specified alignment mask. It doubles as final payload padding helpmate in +// order to keep the kernel happy. +func (a *AlignedBuff) alignWrite(m int) { + pos := (a.pos + m) & ^m + if pos != a.pos { + a.data = append(a.data, padding[:pos-a.pos]...) + a.pos = pos + } +} + +// This is ... ugly. 
+var uint16AlignMask = int(unsafe.Alignof(uint16(0)) - 1) +var uint32AlignMask = int(unsafe.Alignof(uint32(0)) - 1) +var uint64AlignMask = int(unsafe.Alignof(uint64(0)) - 1) +var padding = bytes.Repeat([]byte{0}, uint64AlignMask) + +var int32AlignMask = int(unsafe.Alignof(int32(0)) - 1) + +// And this even worse. +var uintSize = unsafe.Sizeof(uint32(0)) diff --git a/vendor/github.com/google/nftables/binaryutil/binaryutil.go b/vendor/github.com/google/nftables/binaryutil/binaryutil.go new file mode 100644 index 0000000000..e61973f07a --- /dev/null +++ b/vendor/github.com/google/nftables/binaryutil/binaryutil.go @@ -0,0 +1,125 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package binaryutil contains convenience wrappers around encoding/binary. +package binaryutil + +import ( + "bytes" + "encoding/binary" + "unsafe" +) + +// ByteOrder is like binary.ByteOrder, but allocates memory and returns byte +// slices, for convenience. +type ByteOrder interface { + PutUint16(v uint16) []byte + PutUint32(v uint32) []byte + PutUint64(v uint64) []byte + Uint16(b []byte) uint16 + Uint32(b []byte) uint32 + Uint64(b []byte) uint64 +} + +// NativeEndian is either little endian or big endian, depending on the native +// endian-ness, and allocates memory and returns byte slices, for convenience. 
+var NativeEndian ByteOrder = &nativeEndian{} + +type nativeEndian struct{} + +func (nativeEndian) PutUint16(v uint16) []byte { + buf := make([]byte, 2) + *(*uint16)(unsafe.Pointer(&buf[0])) = v + return buf +} + +func (nativeEndian) PutUint32(v uint32) []byte { + buf := make([]byte, 4) + *(*uint32)(unsafe.Pointer(&buf[0])) = v + return buf +} + +func (nativeEndian) PutUint64(v uint64) []byte { + buf := make([]byte, 8) + *(*uint64)(unsafe.Pointer(&buf[0])) = v + return buf +} + +func (nativeEndian) Uint16(b []byte) uint16 { + return *(*uint16)(unsafe.Pointer(&b[0])) +} + +func (nativeEndian) Uint32(b []byte) uint32 { + return *(*uint32)(unsafe.Pointer(&b[0])) +} + +func (nativeEndian) Uint64(b []byte) uint64 { + return *(*uint64)(unsafe.Pointer(&b[0])) +} + +// BigEndian is like binary.BigEndian, but allocates memory and returns byte +// slices, for convenience. +var BigEndian ByteOrder = &bigEndian{} + +type bigEndian struct{} + +func (bigEndian) PutUint16(v uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, v) + return buf +} + +func (bigEndian) PutUint32(v uint32) []byte { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, v) + return buf +} + +func (bigEndian) PutUint64(v uint64) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, v) + return buf +} + +func (bigEndian) Uint16(b []byte) uint16 { + return binary.BigEndian.Uint16(b) +} + +func (bigEndian) Uint32(b []byte) uint32 { + return binary.BigEndian.Uint32(b) +} + +func (bigEndian) Uint64(b []byte) uint64 { + return binary.BigEndian.Uint64(b) +} + +// For dealing with types not supported by the encoding/binary interface + +func PutInt32(v int32) []byte { + buf := make([]byte, 4) + *(*int32)(unsafe.Pointer(&buf[0])) = v + return buf +} + +func Int32(b []byte) int32 { + return *(*int32)(unsafe.Pointer(&b[0])) +} + +func PutString(s string) []byte { + return []byte(s) +} + +func String(b []byte) string { + return string(bytes.TrimRight(b, "\x00")) +} diff --git 
a/vendor/github.com/google/nftables/chain.go b/vendor/github.com/google/nftables/chain.go new file mode 100644 index 0000000000..9928d63eff --- /dev/null +++ b/vendor/github.com/google/nftables/chain.go @@ -0,0 +1,283 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// ChainHook specifies at which step in packet processing the Chain should be +// executed. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_hooks +type ChainHook uint32 + +// Possible ChainHook values. +var ( + ChainHookPrerouting *ChainHook = ChainHookRef(unix.NF_INET_PRE_ROUTING) + ChainHookInput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_IN) + ChainHookForward *ChainHook = ChainHookRef(unix.NF_INET_FORWARD) + ChainHookOutput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_OUT) + ChainHookPostrouting *ChainHook = ChainHookRef(unix.NF_INET_POST_ROUTING) + ChainHookIngress *ChainHook = ChainHookRef(unix.NF_NETDEV_INGRESS) +) + +// ChainHookRef returns a pointer to a ChainHookRef value. +func ChainHookRef(h ChainHook) *ChainHook { + return &h +} + +// ChainPriority orders the chain relative to Netfilter internal operations. 
See +// also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_priority +type ChainPriority int32 + +// Possible ChainPriority values. +var ( // from /usr/include/linux/netfilter_ipv4.h + ChainPriorityFirst *ChainPriority = ChainPriorityRef(math.MinInt32) + ChainPriorityConntrackDefrag *ChainPriority = ChainPriorityRef(-400) + ChainPriorityRaw *ChainPriority = ChainPriorityRef(-300) + ChainPrioritySELinuxFirst *ChainPriority = ChainPriorityRef(-225) + ChainPriorityConntrack *ChainPriority = ChainPriorityRef(-200) + ChainPriorityMangle *ChainPriority = ChainPriorityRef(-150) + ChainPriorityNATDest *ChainPriority = ChainPriorityRef(-100) + ChainPriorityFilter *ChainPriority = ChainPriorityRef(0) + ChainPrioritySecurity *ChainPriority = ChainPriorityRef(50) + ChainPriorityNATSource *ChainPriority = ChainPriorityRef(100) + ChainPrioritySELinuxLast *ChainPriority = ChainPriorityRef(225) + ChainPriorityConntrackHelper *ChainPriority = ChainPriorityRef(300) + ChainPriorityConntrackConfirm *ChainPriority = ChainPriorityRef(math.MaxInt32) + ChainPriorityLast *ChainPriority = ChainPriorityRef(math.MaxInt32) +) + +// ChainPriorityRef returns a pointer to a ChainPriority value. +func ChainPriorityRef(p ChainPriority) *ChainPriority { + return &p +} + +// ChainType defines what this chain will be used for. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_types +type ChainType string + +// Possible ChainType values. +const ( + ChainTypeFilter ChainType = "filter" + ChainTypeRoute ChainType = "route" + ChainTypeNAT ChainType = "nat" +) + +// ChainPolicy defines what this chain default policy will be. +type ChainPolicy uint32 + +// Possible ChainPolicy values. +const ( + ChainPolicyDrop ChainPolicy = iota + ChainPolicyAccept +) + +// A Chain contains Rules. 
See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains +type Chain struct { + Name string + Table *Table + Hooknum *ChainHook + Priority *ChainPriority + Type ChainType + Policy *ChainPolicy +} + +// AddChain adds the specified Chain. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_base_chains +func (cc *Conn) AddChain(c *Chain) *Chain { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")}, + {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, + }) + + if c.Hooknum != nil && c.Priority != nil { + hookAttr := []netlink.Attribute{ + {Type: unix.NFTA_HOOK_HOOKNUM, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Hooknum))}, + {Type: unix.NFTA_HOOK_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Priority))}, + } + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_CHAIN_HOOK, Data: cc.marshalAttr(hookAttr)}, + })...) + } + + if c.Policy != nil { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_CHAIN_POLICY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Policy))}, + })...) + } + if c.Type != "" { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, + })...) + } + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(c.Table.Family), 0), data...), + }) + + return c +} + +// DelChain deletes the specified Chain. 
See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Deleting_chains +func (cc *Conn) DelChain(c *Chain) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")}, + {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, + }) + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(c.Table.Family), 0), data...), + }) +} + +// FlushChain removes all rules within the specified Chain. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Flushing_chain +func (cc *Conn) FlushChain(c *Chain) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(c.Table.Family), 0), data...), + }) +} + +// ListChains returns currently configured chains in the kernel +func (cc *Conn) ListChains() ([]*Chain, error) { + return cc.ListChainsOfTableFamily(TableFamilyUnspecified) +} + +// ListChainsOfTableFamily returns currently configured chains for the specified +// family in the kernel. It lists all chains ins all tables if family is +// TableFamilyUnspecified. 
+func (cc *Conn) ListChainsOfTableFamily(family TableFamily) ([]*Chain, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + msg := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN), + Flags: netlink.Request | netlink.Dump, + }, + Data: extraHeader(uint8(family), 0), + } + + response, err := conn.Execute(msg) + if err != nil { + return nil, err + } + + var chains []*Chain + for _, m := range response { + c, err := chainFromMsg(m) + if err != nil { + return nil, err + } + + chains = append(chains, c) + } + + return chains, nil +} + +func chainFromMsg(msg netlink.Message) (*Chain, error) { + chainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN) + if got, want := msg.Header.Type, chainHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + + var c Chain + + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + + for ad.Next() { + switch ad.Type() { + case unix.NFTA_CHAIN_NAME: + c.Name = ad.String() + case unix.NFTA_TABLE_NAME: + c.Table = &Table{Name: ad.String()} + // msg[0] carries TableFamily byte indicating whether it is IPv4, IPv6 or something else + c.Table.Family = TableFamily(msg.Data[0]) + case unix.NFTA_CHAIN_TYPE: + c.Type = ChainType(ad.String()) + case unix.NFTA_CHAIN_POLICY: + policy := ChainPolicy(binaryutil.BigEndian.Uint32(ad.Bytes())) + c.Policy = &policy + case unix.NFTA_CHAIN_HOOK: + ad.Do(func(b []byte) error { + c.Hooknum, c.Priority, err = hookFromMsg(b) + return err + }) + } + } + + return &c, nil +} + +func hookFromMsg(b []byte) (*ChainHook, *ChainPriority, error) { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, nil, err + } + + ad.ByteOrder = binary.BigEndian + + var hooknum ChainHook + var prio ChainPriority + + for ad.Next() { 
+ switch ad.Type() { + case unix.NFTA_HOOK_HOOKNUM: + hooknum = ChainHook(ad.Uint32()) + case unix.NFTA_HOOK_PRIORITY: + prio = ChainPriority(ad.Uint32()) + } + } + + return &hooknum, &prio, nil +} diff --git a/vendor/github.com/google/nftables/compat_policy.go b/vendor/github.com/google/nftables/compat_policy.go new file mode 100644 index 0000000000..c1f390855d --- /dev/null +++ b/vendor/github.com/google/nftables/compat_policy.go @@ -0,0 +1,89 @@ +package nftables + +import ( + "fmt" + + "github.com/google/nftables/expr" + "golang.org/x/sys/unix" +) + +const nft_RULE_COMPAT_F_INV uint32 = (1 << 1) +const nft_RULE_COMPAT_F_MASK uint32 = nft_RULE_COMPAT_F_INV + +// Used by xt match or target like xt_tcpudp to set compat policy between xtables and nftables +// https://elixir.bootlin.com/linux/v5.12/source/net/netfilter/nft_compat.c#L187 +type compatPolicy struct { + Proto uint32 + Flag uint32 +} + +var xtMatchCompatMap map[string]*compatPolicy = map[string]*compatPolicy{ + "tcp": { + Proto: unix.IPPROTO_TCP, + }, + "udp": { + Proto: unix.IPPROTO_UDP, + }, + "udplite": { + Proto: unix.IPPROTO_UDPLITE, + }, + "tcpmss": { + Proto: unix.IPPROTO_TCP, + }, + "sctp": { + Proto: unix.IPPROTO_SCTP, + }, + "osf": { + Proto: unix.IPPROTO_TCP, + }, + "ipcomp": { + Proto: unix.IPPROTO_COMP, + }, + "esp": { + Proto: unix.IPPROTO_ESP, + }, +} + +var xtTargetCompatMap map[string]*compatPolicy = map[string]*compatPolicy{ + "TCPOPTSTRIP": { + Proto: unix.IPPROTO_TCP, + }, + "TCPMSS": { + Proto: unix.IPPROTO_TCP, + }, +} + +func getCompatPolicy(exprs []expr.Any) (*compatPolicy, error) { + var exprItem expr.Any + var compat *compatPolicy + + for _, iter := range exprs { + var tmpExprItem expr.Any + var tmpCompat *compatPolicy + switch item := iter.(type) { + case *expr.Match: + if compat, ok := xtMatchCompatMap[item.Name]; ok { + tmpCompat = compat + tmpExprItem = item + } else { + continue + } + case *expr.Target: + if compat, ok := xtTargetCompatMap[item.Name]; ok { + tmpCompat = 
compat + tmpExprItem = item + } else { + continue + } + default: + continue + } + if compat == nil { + compat = tmpCompat + exprItem = tmpExprItem + } else if *compat != *tmpCompat { + return nil, fmt.Errorf("%#v and %#v has conflict compat policy %#v vs %#v", exprItem, tmpExprItem, compat, tmpCompat) + } + } + return compat, nil +} diff --git a/vendor/github.com/google/nftables/conn.go b/vendor/github.com/google/nftables/conn.go new file mode 100644 index 0000000000..711d7f6ae2 --- /dev/null +++ b/vendor/github.com/google/nftables/conn.go @@ -0,0 +1,315 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "errors" + "fmt" + "sync" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/mdlayher/netlink" + "github.com/mdlayher/netlink/nltest" + "golang.org/x/sys/unix" +) + +// A Conn represents a netlink connection of the nftables family. +// +// All methods return their input, so that variables can be defined from string +// literals when desired. +// +// Commands are buffered. Flush sends all buffered commands in a single batch. +type Conn struct { + TestDial nltest.Func // for testing only; passed to nltest.Dial + NetNS int // fd referencing the network namespace netlink will interact with. + + lasting bool // establish a lasting connection to be used across multiple netlink operations. 
+ mu sync.Mutex // protects the following state + messages []netlink.Message + err error + nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. +} + +// ConnOption is an option to change the behavior of the nftables Conn returned by Open. +type ConnOption func(*Conn) + +// New returns a netlink connection for querying and modifying nftables. Some +// aspects of the new netlink connection can be configured using the options +// WithNetNSFd, WithTestDial, and AsLasting. +// +// A lasting netlink connection should be closed by calling CloseLasting() to +// close the underlying lasting netlink connection, cancelling all pending +// operations using this connection. +func New(opts ...ConnOption) (*Conn, error) { + cc := &Conn{} + for _, opt := range opts { + opt(cc) + } + + if !cc.lasting { + return cc, nil + } + + nlconn, err := cc.dialNetlink() + if err != nil { + return nil, err + } + cc.nlconn = nlconn + return cc, nil +} + +// AsLasting creates the new netlink connection as a lasting connection that is +// reused across multiple netlink operations, instead of opening and closing the +// underlying netlink connection only for the duration of a single netlink +// operation. +func AsLasting() ConnOption { + return func(cc *Conn) { + // We cannot create the underlying connection yet, as we are called + // anywhere in the option processing chain and there might be later + // options still modifying connection behavior. + cc.lasting = true + } +} + +// WithNetNSFd sets the network namespace to create a new netlink connection to: +// the fd must reference a network namespace. +func WithNetNSFd(fd int) ConnOption { + return func(cc *Conn) { + cc.NetNS = fd + } +} + +// WithTestDial sets the specified nltest.Func when creating a new netlink +// connection. 
+func WithTestDial(f nltest.Func) ConnOption { + return func(cc *Conn) { + cc.TestDial = f + } +} + +// netlinkCloser is returned by netlinkConn(UnderLock) and must be called after +// being done with the returned netlink connection in order to properly close +// this connection, if necessary. +type netlinkCloser func() error + +// netlinkConn returns a netlink connection together with a netlinkCloser that +// later must be called by the caller when it doesn't need the returned netlink +// connection anymore. The netlinkCloser will close the netlink connection when +// necessary. If New has been told to create a lasting connection, then this +// lasting netlink connection will be returned, otherwise a new "transient" +// netlink connection will be opened and returned instead. netlinkConn must not +// be called while the Conn.mu lock is currently helt (this will cause a +// deadlock). Use netlinkConnUnderLock instead in such situations. +func (cc *Conn) netlinkConn() (*netlink.Conn, netlinkCloser, error) { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.netlinkConnUnderLock() +} + +// netlinkConnUnderLock works like netlinkConn but must be called while holding +// the Conn.mu lock. 
+func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) { + if cc.nlconn != nil { + return cc.nlconn, func() error { return nil }, nil + } + nlconn, err := cc.dialNetlink() + if err != nil { + return nil, nil, err + } + return nlconn, func() error { return nlconn.Close() }, nil +} + +func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]netlink.Message, error) { + if nlconn == nil { + return nil, errors.New("netlink conn is not initialized") + } + + // first receive will be the message that we expect + reply, err := nlconn.Receive() + if err != nil { + return nil, err + } + + if (sentMsgFlags & netlink.Acknowledge) == 0 { + // we did not request an ack + return reply, nil + } + + if (sentMsgFlags & netlink.Dump) == netlink.Dump { + // sent message has Dump flag set, there will be no acks + // https://github.com/torvalds/linux/blob/7e062cda7d90543ac8c7700fc7c5527d0c0f22ad/net/netlink/af_netlink.c#L2387-L2390 + return reply, nil + } + + // Dump flag is not set, we expect an ack + ack, err := nlconn.Receive() + if err != nil { + return nil, err + } + + if len(ack) == 0 { + return nil, errors.New("received an empty ack") + } + + msg := ack[0] + if msg.Header.Type != netlink.Error { + // acks should be delivered as NLMSG_ERROR + return nil, fmt.Errorf("expected header %v, but got %v", netlink.Error, msg.Header.Type) + } + + if binaryutil.BigEndian.Uint32(msg.Data[:4]) != 0 { + // if errno field is not set to 0 (success), this is an error + return nil, fmt.Errorf("error delivered in message: %v", msg.Data) + } + + return reply, nil +} + +// CloseLasting closes the lasting netlink connection that has been opened using +// AsLasting option when creating this connection. If either no lasting netlink +// connection has been opened or the lasting connection is already in the +// process of closing or has been closed, CloseLasting will immediately return +// without any error. 
+// +// CloseLasting will terminate all pending netlink operations using the lasting +// connection. +// +// After closing a lasting connection, the connection will revert to using +// on-demand transient netlink connections when calling further netlink +// operations (such as GetTables). +func (cc *Conn) CloseLasting() error { + // Don't acquire the lock for the whole duration of the CloseLasting + // operation, but instead only so long as to make sure to only run the + // netlink socket close on the first time with a lasting netlink socket. As + // there is only the New() constructor, but no Open() method, it's + // impossible to reopen a lasting connection. + cc.mu.Lock() + nlconn := cc.nlconn + cc.nlconn = nil + cc.mu.Unlock() + if nlconn != nil { + return nlconn.Close() + } + return nil +} + +// Flush sends all buffered commands in a single batch to nftables. +func (cc *Conn) Flush() error { + cc.mu.Lock() + defer func() { + cc.messages = nil + cc.mu.Unlock() + }() + if len(cc.messages) == 0 { + // Messages were already programmed, returning nil + return nil + } + if cc.err != nil { + return cc.err // serialization error + } + conn, closer, err := cc.netlinkConnUnderLock() + if err != nil { + return err + } + defer func() { _ = closer() }() + + if _, err := conn.SendMessages(batch(cc.messages)); err != nil { + return fmt.Errorf("SendMessages: %w", err) + } + + // Fetch the requested acknowledgement for each message we sent. + for _, msg := range cc.messages { + if msg.Header.Flags&netlink.Acknowledge == 0 { + continue // message did not request an acknowledgement + } + if _, err := conn.Receive(); err != nil { + return fmt.Errorf("conn.Receive: %w", err) + } + } + + return nil +} + +// FlushRuleset flushes the entire ruleset. 
See also +// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level +func (cc *Conn) FlushRuleset() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: extraHeader(0, 0), + }) +} + +func (cc *Conn) dialNetlink() (*netlink.Conn, error) { + if cc.TestDial != nil { + return nltest.Dial(cc.TestDial), nil + } + + return netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS}) +} + +func (cc *Conn) setErr(err error) { + if cc.err != nil { + return + } + cc.err = err +} + +func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte { + b, err := netlink.MarshalAttributes(attrs) + if err != nil { + cc.setErr(err) + return nil + } + return b +} + +func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { + b, err := expr.Marshal(fam, e) + if err != nil { + cc.setErr(err) + return nil + } + return b +} + +func batch(messages []netlink.Message) []netlink.Message { + batch := []netlink.Message{ + { + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), + Flags: netlink.Request, + }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + }, + } + + batch = append(batch, messages...) + + batch = append(batch, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), + Flags: netlink.Request, + }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + }) + + return batch +} diff --git a/vendor/github.com/google/nftables/counter.go b/vendor/github.com/google/nftables/counter.go new file mode 100644 index 0000000000..e4282029be --- /dev/null +++ b/vendor/github.com/google/nftables/counter.go @@ -0,0 +1,70 @@ +// Copyright 2018 Google LLC. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// CounterObj implements Obj. +type CounterObj struct { + Table *Table + Name string // e.g. “fwded” + + Bytes uint64 + Packets uint64 +} + +func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error { + for ad.Next() { + switch ad.Type() { + case unix.NFTA_COUNTER_BYTES: + c.Bytes = ad.Uint64() + case unix.NFTA_COUNTER_PACKETS: + c.Packets = ad.Uint64() + } + } + return ad.Err() +} + +func (c *CounterObj) table() *Table { + return c.Table +} + +func (c *CounterObj) family() TableFamily { + return c.Table.Family +} + +func (c *CounterObj) marshal(data bool) ([]byte, error) { + obj, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(c.Bytes)}, + {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(c.Packets)}, + }) + if err != nil { + return nil, err + } + const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(c.Table.Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(c.Name + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(NFT_OBJECT_COUNTER)}, + } + if data { + attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: obj}) + } + 
return netlink.MarshalAttributes(attrs) +} diff --git a/vendor/github.com/google/nftables/doc.go b/vendor/github.com/google/nftables/doc.go new file mode 100644 index 0000000000..41985b35e9 --- /dev/null +++ b/vendor/github.com/google/nftables/doc.go @@ -0,0 +1,16 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nftables manipulates Linux nftables (the iptables successor). +package nftables diff --git a/vendor/github.com/google/nftables/expr/bitwise.go b/vendor/github.com/google/nftables/expr/bitwise.go new file mode 100644 index 0000000000..62f7f9bae6 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/bitwise.go @@ -0,0 +1,102 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Bitwise struct { + SourceRegister uint32 + DestRegister uint32 + Len uint32 + Mask []byte + Xor []byte +} + +func (e *Bitwise) marshal(fam byte) ([]byte, error) { + mask, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Mask}, + }) + if err != nil { + return nil, err + } + xor, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Xor}, + }) + if err != nil { + return nil, err + } + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_BITWISE_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, + {Type: unix.NFTA_BITWISE_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, + {Type: unix.NFTA_BITWISE_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + {Type: unix.NLA_F_NESTED | unix.NFTA_BITWISE_MASK, Data: mask}, + {Type: unix.NLA_F_NESTED | unix.NFTA_BITWISE_XOR, Data: xor}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("bitwise\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Bitwise) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_BITWISE_SREG: + e.SourceRegister = ad.Uint32() + case unix.NFTA_BITWISE_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_BITWISE_LEN: + e.Len = ad.Uint32() + case unix.NFTA_BITWISE_MASK: + // Since NFTA_BITWISE_MASK is nested, it requires additional decoding + ad.Nested(func(nad *netlink.AttributeDecoder) error { + for nad.Next() { + switch nad.Type() { + case unix.NFTA_DATA_VALUE: + e.Mask = nad.Bytes() + } + } + return nil + }) + case 
unix.NFTA_BITWISE_XOR: + // Since NFTA_BITWISE_XOR is nested, it requires additional decoding + ad.Nested(func(nad *netlink.AttributeDecoder) error { + for nad.Next() { + switch nad.Type() { + case unix.NFTA_DATA_VALUE: + e.Xor = nad.Bytes() + } + } + return nil + }) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/byteorder.go b/vendor/github.com/google/nftables/expr/byteorder.go new file mode 100644 index 0000000000..2450e8f8fe --- /dev/null +++ b/vendor/github.com/google/nftables/expr/byteorder.go @@ -0,0 +1,59 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package expr + +import ( + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type ByteorderOp uint32 + +const ( + ByteorderNtoh ByteorderOp = unix.NFT_BYTEORDER_NTOH + ByteorderHton ByteorderOp = unix.NFT_BYTEORDER_HTON +) + +type Byteorder struct { + SourceRegister uint32 + DestRegister uint32 + Op ByteorderOp + Len uint32 + Size uint32 +} + +func (e *Byteorder) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_BYTEORDER_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, + {Type: unix.NFTA_BYTEORDER_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, + {Type: unix.NFTA_BYTEORDER_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, + {Type: unix.NFTA_BYTEORDER_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + {Type: unix.NFTA_BYTEORDER_SIZE, Data: binaryutil.BigEndian.PutUint32(e.Size)}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("byteorder\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Byteorder) unmarshal(fam byte, data []byte) error { + return fmt.Errorf("not yet implemented") +} diff --git a/vendor/github.com/google/nftables/expr/connlimit.go b/vendor/github.com/google/nftables/expr/connlimit.go new file mode 100644 index 0000000000..b712967a3f --- /dev/null +++ b/vendor/github.com/google/nftables/expr/connlimit.go @@ -0,0 +1,70 @@ +// Copyright 2019 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const ( + // Per https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1167 + NFTA_CONNLIMIT_UNSPEC = iota + NFTA_CONNLIMIT_COUNT + NFTA_CONNLIMIT_FLAGS + NFT_CONNLIMIT_F_INV = 1 +) + +// Per https://git.netfilter.org/libnftnl/tree/src/expr/connlimit.c?id=84d12cfacf8ddd857a09435f3d982ab6250d250c +type Connlimit struct { + Count uint32 + Flags uint32 +} + +func (e *Connlimit) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: NFTA_CONNLIMIT_COUNT, Data: binaryutil.BigEndian.PutUint32(e.Count)}, + {Type: NFTA_CONNLIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}, + }) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("connlimit\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Connlimit) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_CONNLIMIT_COUNT: + e.Count = binaryutil.BigEndian.Uint32(ad.Bytes()) + case NFTA_CONNLIMIT_FLAGS: + e.Flags = binaryutil.BigEndian.Uint32(ad.Bytes()) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/counter.go 
b/vendor/github.com/google/nftables/expr/counter.go new file mode 100644 index 0000000000..dd6eab3f4e --- /dev/null +++ b/vendor/github.com/google/nftables/expr/counter.go @@ -0,0 +1,60 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Counter struct { + Bytes uint64 + Packets uint64 +} + +func (e *Counter) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(e.Bytes)}, + {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(e.Packets)}, + }) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("counter\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Counter) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_COUNTER_BYTES: + e.Bytes = ad.Uint64() + case unix.NFTA_COUNTER_PACKETS: + e.Packets = ad.Uint64() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/ct.go 
b/vendor/github.com/google/nftables/expr/ct.go new file mode 100644 index 0000000000..1a0ee68b46 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/ct.go @@ -0,0 +1,115 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// CtKey specifies which piece of conntrack information should be loaded. See +// also https://wiki.nftables.org/wiki-nftables/index.php/Matching_connection_tracking_stateful_metainformation +type CtKey uint32 + +// Possible CtKey values. 
+const ( + CtKeySTATE CtKey = unix.NFT_CT_STATE + CtKeyDIRECTION CtKey = unix.NFT_CT_DIRECTION + CtKeySTATUS CtKey = unix.NFT_CT_STATUS + CtKeyMARK CtKey = unix.NFT_CT_MARK + CtKeySECMARK CtKey = unix.NFT_CT_SECMARK + CtKeyEXPIRATION CtKey = unix.NFT_CT_EXPIRATION + CtKeyHELPER CtKey = unix.NFT_CT_HELPER + CtKeyL3PROTOCOL CtKey = unix.NFT_CT_L3PROTOCOL + CtKeySRC CtKey = unix.NFT_CT_SRC + CtKeyDST CtKey = unix.NFT_CT_DST + CtKeyPROTOCOL CtKey = unix.NFT_CT_PROTOCOL + CtKeyPROTOSRC CtKey = unix.NFT_CT_PROTO_SRC + CtKeyPROTODST CtKey = unix.NFT_CT_PROTO_DST + CtKeyLABELS CtKey = unix.NFT_CT_LABELS + CtKeyPKTS CtKey = unix.NFT_CT_PKTS + CtKeyBYTES CtKey = unix.NFT_CT_BYTES + CtKeyAVGPKT CtKey = unix.NFT_CT_AVGPKT + CtKeyZONE CtKey = unix.NFT_CT_ZONE + CtKeyEVENTMASK CtKey = unix.NFT_CT_EVENTMASK + + // https://sources.debian.org/src//nftables/0.9.8-3/src/ct.c/?hl=39#L39 + CtStateBitINVALID uint32 = 1 + CtStateBitESTABLISHED uint32 = 2 + CtStateBitRELATED uint32 = 4 + CtStateBitNEW uint32 = 8 + CtStateBitUNTRACKED uint32 = 64 +) + +// Ct defines type for NFT connection tracking +type Ct struct { + Register uint32 + SourceRegister bool + Key CtKey +} + +func (e *Ct) marshal(fam byte) ([]byte, error) { + regData := []byte{} + exprData, err := netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_CT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, + }, + ) + if err != nil { + return nil, err + } + if e.SourceRegister { + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_CT_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }, + ) + } else { + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_CT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }, + ) + } + if err != nil { + return nil, err + } + exprData = append(exprData, regData...) 
+ + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("ct\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Ct) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_CT_KEY: + e.Key = CtKey(ad.Uint32()) + case unix.NFTA_CT_DREG: + e.Register = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/dup.go b/vendor/github.com/google/nftables/expr/dup.go new file mode 100644 index 0000000000..0114fa796b --- /dev/null +++ b/vendor/github.com/google/nftables/expr/dup.go @@ -0,0 +1,67 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Dup struct { + RegAddr uint32 + RegDev uint32 + IsRegDevSet bool +} + +func (e *Dup) marshal(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_DUP_SREG_ADDR, Data: binaryutil.BigEndian.PutUint32(e.RegAddr)}, + } + + if e.IsRegDevSet { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_DUP_SREG_DEV, Data: binaryutil.BigEndian.PutUint32(e.RegDev)}) + } + + data, err := netlink.MarshalAttributes(attrs) + + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("dup\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Dup) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_DUP_SREG_ADDR: + e.RegAddr = ad.Uint32() + case unix.NFTA_DUP_SREG_DEV: + e.RegDev = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/dynset.go b/vendor/github.com/google/nftables/expr/dynset.go new file mode 100644 index 0000000000..e44f772773 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/dynset.go @@ -0,0 +1,149 @@ +// Copyright 2020 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + "time" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/internal/parseexprfunc" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Not yet supported by unix package +// https://cs.opensource.google/go/x/sys/+/c6bc011c:unix/ztypes_linux.go;l=2027-2036 +const ( + NFTA_DYNSET_EXPRESSIONS = 0xa + NFT_DYNSET_F_EXPR = (1 << 1) +) + +// Dynset represent a rule dynamically adding or updating a set or a map based on an incoming packet. +type Dynset struct { + SrcRegKey uint32 + SrcRegData uint32 + SetID uint32 + SetName string + Operation uint32 + Timeout time.Duration + Invert bool + Exprs []Any +} + +func (e *Dynset) marshal(fam byte) ([]byte, error) { + // See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c + var opAttrs []netlink.Attribute + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)}) + if e.SrcRegData != 0 { + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_DATA, Data: binaryutil.BigEndian.PutUint32(e.SrcRegData)}) + } + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_OP, Data: binaryutil.BigEndian.PutUint32(e.Operation)}) + if e.Timeout != 0 { + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(e.Timeout.Milliseconds()))}) + } + var flags uint32 + if e.Invert { + flags |= unix.NFT_DYNSET_F_INV + } + + opAttrs = append(opAttrs, + netlink.Attribute{Type: unix.NFTA_DYNSET_SET_NAME, Data: []byte(e.SetName + "\x00")}, + netlink.Attribute{Type: unix.NFTA_DYNSET_SET_ID, Data: binaryutil.BigEndian.PutUint32(e.SetID)}) + + // Per https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n170 + if len(e.Exprs) > 0 { + flags |= 
NFT_DYNSET_F_EXPR + switch len(e.Exprs) { + case 1: + exprData, err := Marshal(fam, e.Exprs[0]) + if err != nil { + return nil, err + } + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_EXPR, Data: exprData}) + default: + var elemAttrs []netlink.Attribute + for _, ex := range e.Exprs { + exprData, err := Marshal(fam, ex) + if err != nil { + return nil, err + } + elemAttrs = append(elemAttrs, netlink.Attribute{Type: unix.NFTA_LIST_ELEM, Data: exprData}) + } + elemData, err := netlink.MarshalAttributes(elemAttrs) + if err != nil { + return nil, err + } + opAttrs = append(opAttrs, netlink.Attribute{Type: NFTA_DYNSET_EXPRESSIONS, Data: elemData}) + } + } + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}) + + opData, err := netlink.MarshalAttributes(opAttrs) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("dynset\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: opData}, + }) +} + +func (e *Dynset) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_DYNSET_SET_NAME: + e.SetName = ad.String() + case unix.NFTA_DYNSET_SET_ID: + e.SetID = ad.Uint32() + case unix.NFTA_DYNSET_SREG_KEY: + e.SrcRegKey = ad.Uint32() + case unix.NFTA_DYNSET_SREG_DATA: + e.SrcRegData = ad.Uint32() + case unix.NFTA_DYNSET_OP: + e.Operation = ad.Uint32() + case unix.NFTA_DYNSET_TIMEOUT: + e.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64())) + case unix.NFTA_DYNSET_FLAGS: + e.Invert = (ad.Uint32() & unix.NFT_DYNSET_F_INV) != 0 + case unix.NFTA_DYNSET_EXPR: + exprs, err := parseexprfunc.ParseExprBytesFunc(fam, ad, ad.Bytes()) + if err != nil { + return err + } + e.setInterfaceExprs(exprs) + case NFTA_DYNSET_EXPRESSIONS: + exprs, 
err := parseexprfunc.ParseExprMsgFunc(fam, ad.Bytes()) + if err != nil { + return err + } + e.setInterfaceExprs(exprs) + } + } + return ad.Err() +} + +func (e *Dynset) setInterfaceExprs(exprs []interface{}) { + e.Exprs = make([]Any, len(exprs)) + for i := range exprs { + e.Exprs[i] = exprs[i].(Any) + } +} diff --git a/vendor/github.com/google/nftables/expr/expr.go b/vendor/github.com/google/nftables/expr/expr.go new file mode 100644 index 0000000000..9a9ea7681d --- /dev/null +++ b/vendor/github.com/google/nftables/expr/expr.go @@ -0,0 +1,427 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package expr provides nftables rule expressions. 
+package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/internal/parseexprfunc" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +func init() { + parseexprfunc.ParseExprBytesFunc = func(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]interface{}, error) { + exprs, err := exprsFromBytes(fam, ad, b) + if err != nil { + return nil, err + } + result := make([]interface{}, len(exprs)) + for idx, expr := range exprs { + result[idx] = expr + } + return result, nil + } + parseexprfunc.ParseExprMsgFunc = func(fam byte, b []byte) ([]interface{}, error) { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + var exprs []interface{} + for ad.Next() { + e, err := parseexprfunc.ParseExprBytesFunc(fam, ad, b) + if err != nil { + return e, err + } + exprs = append(exprs, e...) + } + return exprs, ad.Err() + } +} + +// Marshal serializes the specified expression into a byte slice. +func Marshal(fam byte, e Any) ([]byte, error) { + return e.marshal(fam) +} + +// Unmarshal fills an expression from the specified byte slice. 
+func Unmarshal(fam byte, data []byte, e Any) error { + return e.unmarshal(fam, data) +} + +// exprsFromBytes parses nested raw expressions bytes +// to construct nftables expressions +func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, error) { + var exprs []Any + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + var name string + for ad.Next() { + switch ad.Type() { + case unix.NFTA_EXPR_NAME: + name = ad.String() + if name == "notrack" { + e := &Notrack{} + exprs = append(exprs, e) + } + case unix.NFTA_EXPR_DATA: + var e Any + switch name { + case "ct": + e = &Ct{} + case "range": + e = &Range{} + case "meta": + e = &Meta{} + case "cmp": + e = &Cmp{} + case "counter": + e = &Counter{} + case "objref": + e = &Objref{} + case "payload": + e = &Payload{} + case "lookup": + e = &Lookup{} + case "immediate": + e = &Immediate{} + case "bitwise": + e = &Bitwise{} + case "redir": + e = &Redir{} + case "nat": + e = &NAT{} + case "limit": + e = &Limit{} + case "quota": + e = &Quota{} + case "dynset": + e = &Dynset{} + case "log": + e = &Log{} + case "exthdr": + e = &Exthdr{} + case "match": + e = &Match{} + case "target": + e = &Target{} + case "connlimit": + e = &Connlimit{} + case "queue": + e = &Queue{} + case "flow_offload": + e = &FlowOffload{} + case "reject": + e = &Reject{} + case "masq": + e = &Masq{} + } + if e == nil { + // TODO: introduce an opaque expression type so that users know + // something is here. + continue // unsupported expression type + } + + ad.Do(func(b []byte) error { + if err := Unmarshal(fam, b, e); err != nil { + return err + } + // Verdict expressions are a special-case of immediate expressions, so + // if the expression is an immediate writing nothing into the verdict + // register (invalid), re-parse it as a verdict expression. 
+ if imm, isImmediate := e.(*Immediate); isImmediate && imm.Register == unix.NFT_REG_VERDICT && len(imm.Data) == 0 { + e = &Verdict{} + if err := Unmarshal(fam, b, e); err != nil { + return err + } + } + exprs = append(exprs, e) + return nil + }) + } + } + return ad.Err() + }) + return exprs, ad.Err() +} + +// Any is an interface implemented by any expression type. +type Any interface { + marshal(fam byte) ([]byte, error) + unmarshal(fam byte, data []byte) error +} + +// MetaKey specifies which piece of meta information should be loaded. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Matching_packet_metainformation +type MetaKey uint32 + +// Possible MetaKey values. +const ( + MetaKeyLEN MetaKey = unix.NFT_META_LEN + MetaKeyPROTOCOL MetaKey = unix.NFT_META_PROTOCOL + MetaKeyPRIORITY MetaKey = unix.NFT_META_PRIORITY + MetaKeyMARK MetaKey = unix.NFT_META_MARK + MetaKeyIIF MetaKey = unix.NFT_META_IIF + MetaKeyOIF MetaKey = unix.NFT_META_OIF + MetaKeyIIFNAME MetaKey = unix.NFT_META_IIFNAME + MetaKeyOIFNAME MetaKey = unix.NFT_META_OIFNAME + MetaKeyIIFTYPE MetaKey = unix.NFT_META_IIFTYPE + MetaKeyOIFTYPE MetaKey = unix.NFT_META_OIFTYPE + MetaKeySKUID MetaKey = unix.NFT_META_SKUID + MetaKeySKGID MetaKey = unix.NFT_META_SKGID + MetaKeyNFTRACE MetaKey = unix.NFT_META_NFTRACE + MetaKeyRTCLASSID MetaKey = unix.NFT_META_RTCLASSID + MetaKeySECMARK MetaKey = unix.NFT_META_SECMARK + MetaKeyNFPROTO MetaKey = unix.NFT_META_NFPROTO + MetaKeyL4PROTO MetaKey = unix.NFT_META_L4PROTO + MetaKeyBRIIIFNAME MetaKey = unix.NFT_META_BRI_IIFNAME + MetaKeyBRIOIFNAME MetaKey = unix.NFT_META_BRI_OIFNAME + MetaKeyPKTTYPE MetaKey = unix.NFT_META_PKTTYPE + MetaKeyCPU MetaKey = unix.NFT_META_CPU + MetaKeyIIFGROUP MetaKey = unix.NFT_META_IIFGROUP + MetaKeyOIFGROUP MetaKey = unix.NFT_META_OIFGROUP + MetaKeyCGROUP MetaKey = unix.NFT_META_CGROUP + MetaKeyPRANDOM MetaKey = unix.NFT_META_PRANDOM +) + +// Meta loads packet meta information for later comparisons. 
See also +// https://wiki.nftables.org/wiki-nftables/index.php/Matching_packet_metainformation +type Meta struct { + Key MetaKey + SourceRegister bool + Register uint32 +} + +func (e *Meta) marshal(fam byte) ([]byte, error) { + regData := []byte{} + exprData, err := netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_META_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, + }, + ) + if err != nil { + return nil, err + } + if e.SourceRegister { + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_META_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }, + ) + } else { + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_META_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }, + ) + } + if err != nil { + return nil, err + } + exprData = append(exprData, regData...) + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("meta\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Meta) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_META_SREG: + e.Register = ad.Uint32() + e.SourceRegister = true + case unix.NFTA_META_DREG: + e.Register = ad.Uint32() + case unix.NFTA_META_KEY: + e.Key = MetaKey(ad.Uint32()) + } + } + return ad.Err() +} + +// Masq (Masquerade) is a special case of SNAT, where the source address is +// automagically set to the address of the output interface. 
See also +// https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)#Masquerading +type Masq struct { + Random bool + FullyRandom bool + Persistent bool + ToPorts bool + RegProtoMin uint32 + RegProtoMax uint32 +} + +// TODO, Once the constants below are available in golang.org/x/sys/unix, switch to use those. +const ( + // NF_NAT_RANGE_PROTO_RANDOM defines flag for a random masquerade + NF_NAT_RANGE_PROTO_RANDOM = 0x4 + // NF_NAT_RANGE_PROTO_RANDOM_FULLY defines flag for a fully random masquerade + NF_NAT_RANGE_PROTO_RANDOM_FULLY = 0x10 + // NF_NAT_RANGE_PERSISTENT defines flag for a persistent masquerade + NF_NAT_RANGE_PERSISTENT = 0x8 +) + +func (e *Masq) marshal(fam byte) ([]byte, error) { + msgData := []byte{} + if !e.ToPorts { + flags := uint32(0) + if e.Random { + flags |= NF_NAT_RANGE_PROTO_RANDOM + } + if e.FullyRandom { + flags |= NF_NAT_RANGE_PROTO_RANDOM_FULLY + } + if e.Persistent { + flags |= NF_NAT_RANGE_PERSISTENT + } + if flags != 0 { + flagsData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_MASQ_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}}) + if err != nil { + return nil, err + } + msgData = append(msgData, flagsData...) + } + } else { + regsData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_MASQ_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegProtoMin)}}) + if err != nil { + return nil, err + } + msgData = append(msgData, regsData...) + if e.RegProtoMax != 0 { + regsData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_MASQ_REG_PROTO_MAX, Data: binaryutil.BigEndian.PutUint32(e.RegProtoMax)}}) + if err != nil { + return nil, err + } + msgData = append(msgData, regsData...) 
+ } + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("masq\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: msgData}, + }) +} + +func (e *Masq) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_MASQ_REG_PROTO_MIN: + e.ToPorts = true + e.RegProtoMin = ad.Uint32() + case unix.NFTA_MASQ_REG_PROTO_MAX: + e.RegProtoMax = ad.Uint32() + case unix.NFTA_MASQ_FLAGS: + flags := ad.Uint32() + e.Persistent = (flags & NF_NAT_RANGE_PERSISTENT) != 0 + e.Random = (flags & NF_NAT_RANGE_PROTO_RANDOM) != 0 + e.FullyRandom = (flags & NF_NAT_RANGE_PROTO_RANDOM_FULLY) != 0 + } + } + return ad.Err() +} + +// CmpOp specifies which type of comparison should be performed. +type CmpOp uint32 + +// Possible CmpOp values. +const ( + CmpOpEq CmpOp = unix.NFT_CMP_EQ + CmpOpNeq CmpOp = unix.NFT_CMP_NEQ + CmpOpLt CmpOp = unix.NFT_CMP_LT + CmpOpLte CmpOp = unix.NFT_CMP_LTE + CmpOpGt CmpOp = unix.NFT_CMP_GT + CmpOpGte CmpOp = unix.NFT_CMP_GTE +) + +// Cmp compares a register with the specified data. 
+type Cmp struct { + Op CmpOp + Register uint32 + Data []byte +} + +func (e *Cmp) marshal(fam byte) ([]byte, error) { + cmpData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, + }) + if err != nil { + return nil, err + } + exprData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_CMP_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + {Type: unix.NFTA_CMP_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, + {Type: unix.NLA_F_NESTED | unix.NFTA_CMP_DATA, Data: cmpData}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("cmp\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Cmp) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_CMP_SREG: + e.Register = ad.Uint32() + case unix.NFTA_CMP_OP: + e.Op = CmpOp(ad.Uint32()) + case unix.NFTA_CMP_DATA: + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE { + ad.Do(func(b []byte) error { + e.Data = b + return nil + }) + } + return ad.Err() + }) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/exthdr.go b/vendor/github.com/google/nftables/expr/exthdr.go new file mode 100644 index 0000000000..df0c7db0ce --- /dev/null +++ b/vendor/github.com/google/nftables/expr/exthdr.go @@ -0,0 +1,102 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type ExthdrOp uint32 + +const ( + ExthdrOpIpv6 ExthdrOp = unix.NFT_EXTHDR_OP_IPV6 + ExthdrOpTcpopt ExthdrOp = unix.NFT_EXTHDR_OP_TCPOPT +) + +type Exthdr struct { + DestRegister uint32 + Type uint8 + Offset uint32 + Len uint32 + Flags uint32 + Op ExthdrOp + SourceRegister uint32 +} + +func (e *Exthdr) marshal(fam byte) ([]byte, error) { + var attr []netlink.Attribute + + // Operations are differentiated by the Op and whether the SourceRegister + // or DestRegister is set. Mixing them results in EOPNOTSUPP. 
+ if e.SourceRegister != 0 { + attr = []netlink.Attribute{ + {Type: unix.NFTA_EXTHDR_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}} + } else { + attr = []netlink.Attribute{ + {Type: unix.NFTA_EXTHDR_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}} + } + + attr = append(attr, + netlink.Attribute{Type: unix.NFTA_EXTHDR_TYPE, Data: []byte{e.Type}}, + netlink.Attribute{Type: unix.NFTA_EXTHDR_OFFSET, Data: binaryutil.BigEndian.PutUint32(e.Offset)}, + netlink.Attribute{Type: unix.NFTA_EXTHDR_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + netlink.Attribute{Type: unix.NFTA_EXTHDR_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}) + + // Flags is only set if DREG is set + if e.DestRegister != 0 { + attr = append(attr, + netlink.Attribute{Type: unix.NFTA_EXTHDR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}) + } + + data, err := netlink.MarshalAttributes(attr) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("exthdr\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Exthdr) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_EXTHDR_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_EXTHDR_TYPE: + e.Type = ad.Uint8() + case unix.NFTA_EXTHDR_OFFSET: + e.Offset = ad.Uint32() + case unix.NFTA_EXTHDR_LEN: + e.Len = ad.Uint32() + case unix.NFTA_EXTHDR_FLAGS: + e.Flags = ad.Uint32() + case unix.NFTA_EXTHDR_OP: + e.Op = ExthdrOp(ad.Uint32()) + case unix.NFTA_EXTHDR_SREG: + e.SourceRegister = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/fib.go b/vendor/github.com/google/nftables/expr/fib.go new file mode 100644 index 0000000000..f7ee7043a4 --- /dev/null +++ 
b/vendor/github.com/google/nftables/expr/fib.go @@ -0,0 +1,128 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Fib defines fib expression structure +type Fib struct { + Register uint32 + ResultOIF bool + ResultOIFNAME bool + ResultADDRTYPE bool + FlagSADDR bool + FlagDADDR bool + FlagMARK bool + FlagIIF bool + FlagOIF bool + FlagPRESENT bool +} + +func (e *Fib) marshal(fam byte) ([]byte, error) { + data := []byte{} + reg, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_FIB_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }) + if err != nil { + return nil, err + } + data = append(data, reg...) + flags := uint32(0) + if e.FlagSADDR { + flags |= unix.NFTA_FIB_F_SADDR + } + if e.FlagDADDR { + flags |= unix.NFTA_FIB_F_DADDR + } + if e.FlagMARK { + flags |= unix.NFTA_FIB_F_MARK + } + if e.FlagIIF { + flags |= unix.NFTA_FIB_F_IIF + } + if e.FlagOIF { + flags |= unix.NFTA_FIB_F_OIF + } + if e.FlagPRESENT { + flags |= unix.NFTA_FIB_F_PRESENT + } + if flags != 0 { + flg, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_FIB_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, + }) + if err != nil { + return nil, err + } + data = append(data, flg...) 
+ } + results := uint32(0) + if e.ResultOIF { + results |= unix.NFT_FIB_RESULT_OIF + } + if e.ResultOIFNAME { + results |= unix.NFT_FIB_RESULT_OIFNAME + } + if e.ResultADDRTYPE { + results |= unix.NFT_FIB_RESULT_ADDRTYPE + } + if results != 0 { + rslt, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_FIB_RESULT, Data: binaryutil.BigEndian.PutUint32(results)}, + }) + if err != nil { + return nil, err + } + data = append(data, rslt...) + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("fib\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Fib) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_FIB_DREG: + e.Register = ad.Uint32() + case unix.NFTA_FIB_RESULT: + result := ad.Uint32() + e.ResultOIF = (result & unix.NFT_FIB_RESULT_OIF) == 1 + e.ResultOIFNAME = (result & unix.NFT_FIB_RESULT_OIFNAME) == 1 + e.ResultADDRTYPE = (result & unix.NFT_FIB_RESULT_ADDRTYPE) == 1 + case unix.NFTA_FIB_FLAGS: + flags := ad.Uint32() + e.FlagSADDR = (flags & unix.NFTA_FIB_F_SADDR) == 1 + e.FlagDADDR = (flags & unix.NFTA_FIB_F_DADDR) == 1 + e.FlagMARK = (flags & unix.NFTA_FIB_F_MARK) == 1 + e.FlagIIF = (flags & unix.NFTA_FIB_F_IIF) == 1 + e.FlagOIF = (flags & unix.NFTA_FIB_F_OIF) == 1 + e.FlagPRESENT = (flags & unix.NFTA_FIB_F_PRESENT) == 1 + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/flow_offload.go b/vendor/github.com/google/nftables/expr/flow_offload.go new file mode 100644 index 0000000000..54f956f1c9 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/flow_offload.go @@ -0,0 +1,59 @@ +// Copyright 2018 Google LLC. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const NFTNL_EXPR_FLOW_TABLE_NAME = 1 + +type FlowOffload struct { + Name string +} + +func (e *FlowOffload) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: NFTNL_EXPR_FLOW_TABLE_NAME, Data: []byte(e.Name)}, + }) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("flow_offload\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *FlowOffload) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTNL_EXPR_FLOW_TABLE_NAME: + e.Name = ad.String() + } + } + + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/hash.go b/vendor/github.com/google/nftables/expr/hash.go new file mode 100644 index 0000000000..68491770f2 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/hash.go @@ -0,0 +1,87 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type HashType uint32 + +const ( + HashTypeJenkins HashType = unix.NFT_HASH_JENKINS + HashTypeSym HashType = unix.NFT_HASH_SYM +) + +// Hash defines type for nftables internal hashing functions +type Hash struct { + SourceRegister uint32 + DestRegister uint32 + Length uint32 + Modulus uint32 + Seed uint32 + Offset uint32 + Type HashType +} + +func (e *Hash) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_HASH_SREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.SourceRegister))}, + {Type: unix.NFTA_HASH_DREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.DestRegister))}, + {Type: unix.NFTA_HASH_LEN, Data: binaryutil.BigEndian.PutUint32(uint32(e.Length))}, + {Type: unix.NFTA_HASH_MODULUS, Data: binaryutil.BigEndian.PutUint32(uint32(e.Modulus))}, + {Type: unix.NFTA_HASH_SEED, Data: binaryutil.BigEndian.PutUint32(uint32(e.Seed))}, + {Type: unix.NFTA_HASH_OFFSET, Data: binaryutil.BigEndian.PutUint32(uint32(e.Offset))}, + {Type: unix.NFTA_HASH_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("hash\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Hash) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + 
if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_HASH_SREG: + e.SourceRegister = ad.Uint32() + case unix.NFTA_HASH_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_HASH_LEN: + e.Length = ad.Uint32() + case unix.NFTA_HASH_MODULUS: + e.Modulus = ad.Uint32() + case unix.NFTA_HASH_SEED: + e.Seed = ad.Uint32() + case unix.NFTA_HASH_OFFSET: + e.Offset = ad.Uint32() + case unix.NFTA_HASH_TYPE: + e.Type = HashType(ad.Uint32()) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/immediate.go b/vendor/github.com/google/nftables/expr/immediate.go new file mode 100644 index 0000000000..99531f867d --- /dev/null +++ b/vendor/github.com/google/nftables/expr/immediate.go @@ -0,0 +1,79 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package expr + +import ( + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Immediate struct { + Register uint32 + Data []byte +} + +func (e *Immediate) marshal(fam byte) ([]byte, error) { + immData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, + }) + if err != nil { + return nil, err + } + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_IMMEDIATE_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + {Type: unix.NLA_F_NESTED | unix.NFTA_IMMEDIATE_DATA, Data: immData}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Immediate) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_IMMEDIATE_DREG: + e.Register = ad.Uint32() + case unix.NFTA_IMMEDIATE_DATA: + nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) + if err != nil { + return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err) + } + for nestedAD.Next() { + switch nestedAD.Type() { + case unix.NFTA_DATA_VALUE: + e.Data = nestedAD.Bytes() + } + } + if nestedAD.Err() != nil { + return fmt.Errorf("decoding immediate: %v", nestedAD.Err()) + } + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/limit.go b/vendor/github.com/google/nftables/expr/limit.go new file mode 100644 index 0000000000..9ecb41f047 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/limit.go @@ -0,0 +1,128 @@ +// Copyright 2018 Google LLC. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + "errors" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// LimitType represents the type of the limit expression. +type LimitType uint32 + +// Imported from the nft_limit_type enum in netfilter/nf_tables.h. +const ( + LimitTypePkts LimitType = unix.NFT_LIMIT_PKTS + LimitTypePktBytes LimitType = unix.NFT_LIMIT_PKT_BYTES +) + +// LimitTime represents the limit unit. +type LimitTime uint64 + +// Possible limit unit values. +const ( + LimitTimeSecond LimitTime = 1 + LimitTimeMinute LimitTime = 60 + LimitTimeHour LimitTime = 60 * 60 + LimitTimeDay LimitTime = 60 * 60 * 24 + LimitTimeWeek LimitTime = 60 * 60 * 24 * 7 +) + +func limitTime(value uint64) (LimitTime, error) { + switch LimitTime(value) { + case LimitTimeSecond: + return LimitTimeSecond, nil + case LimitTimeMinute: + return LimitTimeMinute, nil + case LimitTimeHour: + return LimitTimeHour, nil + case LimitTimeDay: + return LimitTimeDay, nil + case LimitTimeWeek: + return LimitTimeWeek, nil + default: + return 0, fmt.Errorf("expr: invalid limit unit value %d", value) + } +} + +// Limit represents a rate limit expression. 
+type Limit struct { + Type LimitType + Rate uint64 + Over bool + Unit LimitTime + Burst uint32 +} + +func (l *Limit) marshal(fam byte) ([]byte, error) { + var flags uint32 + if l.Over { + flags = unix.NFT_LIMIT_F_INV + } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_LIMIT_RATE, Data: binaryutil.BigEndian.PutUint64(l.Rate)}, + {Type: unix.NFTA_LIMIT_UNIT, Data: binaryutil.BigEndian.PutUint64(uint64(l.Unit))}, + {Type: unix.NFTA_LIMIT_BURST, Data: binaryutil.BigEndian.PutUint32(l.Burst)}, + {Type: unix.NFTA_LIMIT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(l.Type))}, + {Type: unix.NFTA_LIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, + } + + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("limit\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (l *Limit) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_LIMIT_RATE: + l.Rate = ad.Uint64() + case unix.NFTA_LIMIT_UNIT: + l.Unit, err = limitTime(ad.Uint64()) + if err != nil { + return err + } + case unix.NFTA_LIMIT_BURST: + l.Burst = ad.Uint32() + case unix.NFTA_LIMIT_TYPE: + l.Type = LimitType(ad.Uint32()) + if l.Type != LimitTypePkts && l.Type != LimitTypePktBytes { + return fmt.Errorf("expr: invalid limit type %d", l.Type) + } + case unix.NFTA_LIMIT_FLAGS: + l.Over = (ad.Uint32() & unix.NFT_LIMIT_F_INV) == 1 + default: + return errors.New("expr: unhandled limit netlink attribute") + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/log.go b/vendor/github.com/google/nftables/expr/log.go new file mode 100644 index 0000000000..a712b990f2 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/log.go @@ -0,0 +1,150 @@ +// Copyright 
2019 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type LogLevel uint32 + +const ( + // See https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=5b364657a35f4e4cd5d220ba2a45303d729c8eca#n1226 + LogLevelEmerg LogLevel = iota + LogLevelAlert + LogLevelCrit + LogLevelErr + LogLevelWarning + LogLevelNotice + LogLevelInfo + LogLevelDebug + LogLevelAudit +) + +type LogFlags uint32 + +const ( + // See https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_log.h?id=5b364657a35f4e4cd5d220ba2a45303d729c8eca + LogFlagsTCPSeq LogFlags = 0x01 << iota + LogFlagsTCPOpt + LogFlagsIPOpt + LogFlagsUID + LogFlagsNFLog + LogFlagsMACDecode + LogFlagsMask LogFlags = 0x2f +) + +// Log defines type for NFT logging +// See https://git.netfilter.org/libnftnl/tree/src/expr/log.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n25 +type Log struct { + Level LogLevel + // Refers to log flags (flags all, flags ip options, ...) + Flags LogFlags + // Equivalent to expression flags. + // Indicates that an option is set by setting a bit + // on index referred by the NFTA_LOG_* value. 
+ // See https://cs.opensource.google/go/x/sys/+/3681064d:unix/ztypes_linux.go;l=2126;drc=3681064d51587c1db0324b3d5c23c2ddbcff6e8f + Key uint32 + Snaplen uint32 + Group uint16 + QThreshold uint16 + // Log prefix string content + Data []byte +} + +func (e *Log) marshal(fam byte) ([]byte, error) { + // Per https://git.netfilter.org/libnftnl/tree/src/expr/log.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n129 + attrs := make([]netlink.Attribute, 0) + if e.Key&(1<= /* sic! */ XTablesExtensionNameMaxLen { + name = name[:XTablesExtensionNameMaxLen-1] // leave room for trailing \x00. + } + // Marshalling assumes that the correct Info type for the particular table + // family and Match revision has been set. + info, err := xt.Marshal(xt.TableFamily(fam), e.Rev, e.Info) + if err != nil { + return nil, err + } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_MATCH_NAME, Data: []byte(name + "\x00")}, + {Type: unix.NFTA_MATCH_REV, Data: binaryutil.BigEndian.PutUint32(e.Rev)}, + {Type: unix.NFTA_MATCH_INFO, Data: info}, + } + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("match\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Match) unmarshal(fam byte, data []byte) error { + // Per https://git.netfilter.org/libnftnl/tree/src/expr/match.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n65 + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + var info []byte + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_MATCH_NAME: + // We are forgiving here, accepting any length and even missing terminating \x00. 
+ e.Name = string(bytes.TrimRight(ad.Bytes(), "\x00")) + case unix.NFTA_MATCH_REV: + e.Rev = ad.Uint32() + case unix.NFTA_MATCH_INFO: + info = ad.Bytes() + } + } + if err = ad.Err(); err != nil { + return err + } + e.Info, err = xt.Unmarshal(e.Name, xt.TableFamily(fam), e.Rev, info) + return err +} diff --git a/vendor/github.com/google/nftables/expr/nat.go b/vendor/github.com/google/nftables/expr/nat.go new file mode 100644 index 0000000000..9602c233f2 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/nat.go @@ -0,0 +1,127 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type NATType uint32 + +// Possible NATType values. 
+const ( + NATTypeSourceNAT NATType = unix.NFT_NAT_SNAT + NATTypeDestNAT NATType = unix.NFT_NAT_DNAT +) + +type NAT struct { + Type NATType + Family uint32 // TODO: typed const + RegAddrMin uint32 + RegAddrMax uint32 + RegProtoMin uint32 + RegProtoMax uint32 + Random bool + FullyRandom bool + Persistent bool +} + +// |00048|N-|00001| |len |flags| type| +// |00008|--|00001| |len |flags| type| +// | 6e 61 74 00 | | data | n a t +// |00036|N-|00002| |len |flags| type| +// |00008|--|00001| |len |flags| type| NFTA_NAT_TYPE +// | 00 00 00 01 | | data | NFT_NAT_DNAT +// |00008|--|00002| |len |flags| type| NFTA_NAT_FAMILY +// | 00 00 00 02 | | data | NFPROTO_IPV4 +// |00008|--|00003| |len |flags| type| NFTA_NAT_REG_ADDR_MIN +// | 00 00 00 01 | | data | reg 1 +// |00008|--|00005| |len |flags| type| NFTA_NAT_REG_PROTO_MIN +// | 00 00 00 02 | | data | reg 2 + +func (e *NAT) marshal(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_NAT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, + {Type: unix.NFTA_NAT_FAMILY, Data: binaryutil.BigEndian.PutUint32(e.Family)}, + } + if e.RegAddrMin != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_REG_ADDR_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegAddrMin)}) + if e.RegAddrMax != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_REG_ADDR_MAX, Data: binaryutil.BigEndian.PutUint32(e.RegAddrMax)}) + } + } + if e.RegProtoMin != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegProtoMin)}) + if e.RegProtoMax != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_REG_PROTO_MAX, Data: binaryutil.BigEndian.PutUint32(e.RegProtoMax)}) + } + } + flags := uint32(0) + if e.Random { + flags |= NF_NAT_RANGE_PROTO_RANDOM + } + if e.FullyRandom { + flags |= NF_NAT_RANGE_PROTO_RANDOM_FULLY + } + if e.Persistent { + flags |= NF_NAT_RANGE_PERSISTENT + } + if flags != 0 { + attrs = append(attrs, 
netlink.Attribute{Type: unix.NFTA_NAT_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}) + } + + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("nat\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *NAT) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_NAT_TYPE: + e.Type = NATType(ad.Uint32()) + case unix.NFTA_NAT_FAMILY: + e.Family = ad.Uint32() + case unix.NFTA_NAT_REG_ADDR_MIN: + e.RegAddrMin = ad.Uint32() + case unix.NFTA_NAT_REG_ADDR_MAX: + e.RegAddrMax = ad.Uint32() + case unix.NFTA_NAT_REG_PROTO_MIN: + e.RegProtoMin = ad.Uint32() + case unix.NFTA_NAT_REG_PROTO_MAX: + e.RegProtoMax = ad.Uint32() + case unix.NFTA_NAT_FLAGS: + flags := ad.Uint32() + e.Persistent = (flags & NF_NAT_RANGE_PERSISTENT) != 0 + e.Random = (flags & NF_NAT_RANGE_PROTO_RANDOM) != 0 + e.FullyRandom = (flags & NF_NAT_RANGE_PROTO_RANDOM_FULLY) != 0 + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/notrack.go b/vendor/github.com/google/nftables/expr/notrack.go new file mode 100644 index 0000000000..cb665d363c --- /dev/null +++ b/vendor/github.com/google/nftables/expr/notrack.go @@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Notrack struct{} + +func (e *Notrack) marshal(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("notrack\x00")}, + }) +} + +func (e *Notrack) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + + if err != nil { + return err + } + + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/numgen.go b/vendor/github.com/google/nftables/expr/numgen.go new file mode 100644 index 0000000000..bcbb1bbeb9 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/numgen.go @@ -0,0 +1,78 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package expr + +import ( + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Numgen defines Numgen expression structure +type Numgen struct { + Register uint32 + Modulus uint32 + Type uint32 + Offset uint32 +} + +func (e *Numgen) marshal(fam byte) ([]byte, error) { + // Currently only two types are supported, failing if Type is not of two known types + switch e.Type { + case unix.NFT_NG_INCREMENTAL: + case unix.NFT_NG_RANDOM: + default: + return nil, fmt.Errorf("unsupported numgen type %d", e.Type) + } + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_NG_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + {Type: unix.NFTA_NG_MODULUS, Data: binaryutil.BigEndian.PutUint32(e.Modulus)}, + {Type: unix.NFTA_NG_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)}, + {Type: unix.NFTA_NG_OFFSET, Data: binaryutil.BigEndian.PutUint32(e.Offset)}, + }) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("numgen\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Numgen) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_NG_DREG: + e.Register = ad.Uint32() + case unix.NFTA_NG_MODULUS: + e.Modulus = ad.Uint32() + case unix.NFTA_NG_TYPE: + e.Type = ad.Uint32() + case unix.NFTA_NG_OFFSET: + e.Offset = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/objref.go b/vendor/github.com/google/nftables/expr/objref.go new file mode 100644 index 0000000000..ae9521b91f --- /dev/null +++ b/vendor/github.com/google/nftables/expr/objref.go @@ -0,0 +1,60 @@ +// Copyright 2018 Google LLC. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Objref struct { + Type int // TODO: enum + Name string +} + +func (e *Objref) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_OBJREF_IMM_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, + {Type: unix.NFTA_OBJREF_IMM_NAME, Data: []byte(e.Name)}, // NOT \x00-terminated?! + }) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("objref\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Objref) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_OBJREF_IMM_TYPE: + e.Type = int(ad.Uint32()) + case unix.NFTA_OBJREF_IMM_NAME: + e.Name = ad.String() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/payload.go b/vendor/github.com/google/nftables/expr/payload.go new file mode 100644 index 0000000000..7f698095ca --- /dev/null +++ b/vendor/github.com/google/nftables/expr/payload.go @@ -0,0 +1,131 @@ +// Copyright 2018 Google LLC. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type PayloadBase uint32 +type PayloadCsumType uint32 +type PayloadOperationType uint32 + +// Possible PayloadBase values. +const ( + PayloadBaseLLHeader PayloadBase = unix.NFT_PAYLOAD_LL_HEADER + PayloadBaseNetworkHeader PayloadBase = unix.NFT_PAYLOAD_NETWORK_HEADER + PayloadBaseTransportHeader PayloadBase = unix.NFT_PAYLOAD_TRANSPORT_HEADER +) + +// Possible PayloadCsumType values. +const ( + CsumTypeNone PayloadCsumType = unix.NFT_PAYLOAD_CSUM_NONE + CsumTypeInet PayloadCsumType = unix.NFT_PAYLOAD_CSUM_INET +) + +// Possible PayloadOperationType values. 
+const ( + PayloadLoad PayloadOperationType = iota + PayloadWrite +) + +type Payload struct { + OperationType PayloadOperationType + DestRegister uint32 + SourceRegister uint32 + Base PayloadBase + Offset uint32 + Len uint32 + CsumType PayloadCsumType + CsumOffset uint32 + CsumFlags uint32 +} + +func (e *Payload) marshal(fam byte) ([]byte, error) { + + var attrs []netlink.Attribute + + if e.OperationType == PayloadWrite { + attrs = []netlink.Attribute{ + {Type: unix.NFTA_PAYLOAD_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, + } + } else { + attrs = []netlink.Attribute{ + {Type: unix.NFTA_PAYLOAD_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, + } + } + + attrs = append(attrs, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_BASE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Base))}, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_OFFSET, Data: binaryutil.BigEndian.PutUint32(e.Offset)}, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + ) + + if e.CsumType > 0 { + attrs = append(attrs, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_CSUM_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.CsumType))}, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_CSUM_OFFSET, Data: binaryutil.BigEndian.PutUint32(uint32(e.CsumOffset))}, + ) + if e.CsumFlags > 0 { + attrs = append(attrs, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_CSUM_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.CsumFlags)}, + ) + } + } + + data, err := netlink.MarshalAttributes(attrs) + + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("payload\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Payload) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_PAYLOAD_DREG: + 
e.DestRegister = ad.Uint32() + case unix.NFTA_PAYLOAD_SREG: + e.SourceRegister = ad.Uint32() + e.OperationType = PayloadWrite + case unix.NFTA_PAYLOAD_BASE: + e.Base = PayloadBase(ad.Uint32()) + case unix.NFTA_PAYLOAD_OFFSET: + e.Offset = ad.Uint32() + case unix.NFTA_PAYLOAD_LEN: + e.Len = ad.Uint32() + case unix.NFTA_PAYLOAD_CSUM_TYPE: + e.CsumType = PayloadCsumType(ad.Uint32()) + case unix.NFTA_PAYLOAD_CSUM_OFFSET: + e.CsumOffset = ad.Uint32() + case unix.NFTA_PAYLOAD_CSUM_FLAGS: + e.CsumFlags = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/queue.go b/vendor/github.com/google/nftables/expr/queue.go new file mode 100644 index 0000000000..3d0012dae9 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/queue.go @@ -0,0 +1,82 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type QueueAttribute uint16 + +type QueueFlag uint16 + +// Possible QueueAttribute values +const ( + QueueNum QueueAttribute = unix.NFTA_QUEUE_NUM + QueueTotal QueueAttribute = unix.NFTA_QUEUE_TOTAL + QueueFlags QueueAttribute = unix.NFTA_QUEUE_FLAGS + + // TODO: get into x/sys/unix + QueueFlagBypass QueueFlag = 0x01 + QueueFlagFanout QueueFlag = 0x02 + QueueFlagMask QueueFlag = 0x03 +) + +type Queue struct { + Num uint16 + Total uint16 + Flag QueueFlag +} + +func (e *Queue) marshal(fam byte) ([]byte, error) { + if e.Total == 0 { + e.Total = 1 // The total default value is 1 + } + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_QUEUE_NUM, Data: binaryutil.BigEndian.PutUint16(e.Num)}, + {Type: unix.NFTA_QUEUE_TOTAL, Data: binaryutil.BigEndian.PutUint16(e.Total)}, + {Type: unix.NFTA_QUEUE_FLAGS, Data: binaryutil.BigEndian.PutUint16(uint16(e.Flag))}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("queue\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Queue) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_QUEUE_NUM: + e.Num = ad.Uint16() + case unix.NFTA_QUEUE_TOTAL: + e.Total = ad.Uint16() + case unix.NFTA_QUEUE_FLAGS: + e.Flag = QueueFlag(ad.Uint16()) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/quota.go b/vendor/github.com/google/nftables/expr/quota.go new file mode 100644 index 0000000000..f8bc0f30dd --- /dev/null +++ b/vendor/github.com/google/nftables/expr/quota.go @@ -0,0 +1,76 @@ +// Copyright 2018 Google LLC. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Quota defines a threshold against a number of bytes. +type Quota struct { + Bytes uint64 + Consumed uint64 + Over bool +} + +func (q *Quota) marshal(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_QUOTA_BYTES, Data: binaryutil.BigEndian.PutUint64(q.Bytes)}, + {Type: unix.NFTA_QUOTA_CONSUMED, Data: binaryutil.BigEndian.PutUint64(q.Consumed)}, + } + + flags := uint32(0) + if q.Over { + flags = unix.NFT_QUOTA_F_INV + } + attrs = append(attrs, netlink.Attribute{ + Type: unix.NFTA_QUOTA_FLAGS, + Data: binaryutil.BigEndian.PutUint32(flags), + }) + + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("quota\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (q *Quota) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_QUOTA_BYTES: + q.Bytes = ad.Uint64() + case unix.NFTA_QUOTA_CONSUMED: + q.Consumed = ad.Uint64() + case unix.NFTA_QUOTA_FLAGS: + q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) == 1 + } + 
} + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/range.go b/vendor/github.com/google/nftables/expr/range.go new file mode 100644 index 0000000000..8a1f6ea184 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/range.go @@ -0,0 +1,124 @@ +// Copyright 2019 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Range implements range expression +type Range struct { + Op CmpOp + Register uint32 + FromData []byte + ToData []byte +} + +func (e *Range) marshal(fam byte) ([]byte, error) { + var attrs []netlink.Attribute + var err error + var rangeFromData, rangeToData []byte + + if e.Register > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_RANGE_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}) + } + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_RANGE_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}) + if len(e.FromData) > 0 { + rangeFromData, err = nestedAttr(e.FromData, unix.NFTA_RANGE_FROM_DATA) + if err != nil { + return nil, err + } + } + if len(e.ToData) > 0 { + rangeToData, err = nestedAttr(e.ToData, unix.NFTA_RANGE_TO_DATA) + if err != nil { + return nil, err + } + } + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + data = append(data, rangeFromData...) 
+ data = append(data, rangeToData...) + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("range\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Range) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_RANGE_OP: + e.Op = CmpOp(ad.Uint32()) + case unix.NFTA_RANGE_SREG: + e.Register = ad.Uint32() + case unix.NFTA_RANGE_FROM_DATA: + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE { + ad.Do(func(b []byte) error { + e.FromData = b + return nil + }) + } + return ad.Err() + }) + case unix.NFTA_RANGE_TO_DATA: + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE { + ad.Do(func(b []byte) error { + e.ToData = b + return nil + }) + } + return ad.Err() + }) + } + } + return ad.Err() +} + +func nestedAttr(data []byte, attrType uint16) ([]byte, error) { + ae := netlink.NewAttributeEncoder() + ae.Do(unix.NLA_F_NESTED|attrType, func() ([]byte, error) { + nae := netlink.NewAttributeEncoder() + nae.ByteOrder = binary.BigEndian + nae.Bytes(unix.NFTA_DATA_VALUE, data) + + return nae.Encode() + }) + return ae.Encode() +} diff --git a/vendor/github.com/google/nftables/expr/redirect.go b/vendor/github.com/google/nftables/expr/redirect.go new file mode 100644 index 0000000000..1c6f62213b --- /dev/null +++ b/vendor/github.com/google/nftables/expr/redirect.go @@ -0,0 +1,71 @@ +// Copyright 2018 Google LLC. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Redir struct { + RegisterProtoMin uint32 + RegisterProtoMax uint32 + Flags uint32 +} + +func (e *Redir) marshal(fam byte) ([]byte, error) { + var attrs []netlink.Attribute + if e.RegisterProtoMin > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegisterProtoMin)}) + } + if e.RegisterProtoMax > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_REG_PROTO_MAX, Data: binaryutil.BigEndian.PutUint32(e.RegisterProtoMax)}) + } + if e.Flags > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}) + } + + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("redir\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Redir) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_REDIR_REG_PROTO_MIN: + e.RegisterProtoMin = ad.Uint32() + case unix.NFTA_REDIR_REG_PROTO_MAX: + 
e.RegisterProtoMax = ad.Uint32() + case unix.NFTA_REDIR_FLAGS: + e.Flags = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/reject.go b/vendor/github.com/google/nftables/expr/reject.go new file mode 100644 index 0000000000..a742626173 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/reject.go @@ -0,0 +1,59 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Reject struct { + Type uint32 + Code uint8 +} + +func (e *Reject) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_REJECT_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)}, + {Type: unix.NFTA_REJECT_ICMP_CODE, Data: []byte{e.Code}}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("reject\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Reject) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_REJECT_TYPE: + e.Type = ad.Uint32() + case unix.NFTA_REJECT_ICMP_CODE: + e.Code = 
ad.Uint8() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/rt.go b/vendor/github.com/google/nftables/expr/rt.go new file mode 100644 index 0000000000..c3be7ffc43 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/rt.go @@ -0,0 +1,55 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type RtKey uint32 + +const ( + RtClassid RtKey = unix.NFT_RT_CLASSID + RtNexthop4 RtKey = unix.NFT_RT_NEXTHOP4 + RtNexthop6 RtKey = unix.NFT_RT_NEXTHOP6 + RtTCPMSS RtKey = unix.NFT_RT_TCPMSS +) + +type Rt struct { + Register uint32 + Key RtKey +} + +func (e *Rt) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, + {Type: unix.NFTA_RT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("rt\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Rt) unmarshal(fam byte, data []byte) error { + return fmt.Errorf("not yet implemented") +} diff --git a/vendor/github.com/google/nftables/expr/target.go 
b/vendor/github.com/google/nftables/expr/target.go new file mode 100644 index 0000000000..e531a9f7f3 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/target.go @@ -0,0 +1,79 @@ +package expr + +import ( + "bytes" + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/xt" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// See https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n28 +const XTablesExtensionNameMaxLen = 29 + +// See https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n30 +type Target struct { + Name string + Rev uint32 + Info xt.InfoAny +} + +func (e *Target) marshal(fam byte) ([]byte, error) { + // Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n38 + name := e.Name + // limit the extension name as (some) user-space tools do and leave room for + // trailing \x00 + if len(name) >= /* sic! */ XTablesExtensionNameMaxLen { + name = name[:XTablesExtensionNameMaxLen-1] // leave room for trailing \x00. + } + // Marshalling assumes that the correct Info type for the particular table + // family and Match revision has been set. 
+ info, err := xt.Marshal(xt.TableFamily(fam), e.Rev, e.Info) + if err != nil { + return nil, err + } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_TARGET_NAME, Data: []byte(name + "\x00")}, + {Type: unix.NFTA_TARGET_REV, Data: binaryutil.BigEndian.PutUint32(e.Rev)}, + {Type: unix.NFTA_TARGET_INFO, Data: info}, + } + + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("target\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Target) unmarshal(fam byte, data []byte) error { + // Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n65 + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + var info []byte + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_TARGET_NAME: + // We are forgiving here, accepting any length and even missing terminating \x00. + e.Name = string(bytes.TrimRight(ad.Bytes(), "\x00")) + case unix.NFTA_TARGET_REV: + e.Rev = ad.Uint32() + case unix.NFTA_TARGET_INFO: + info = ad.Bytes() + } + } + if err = ad.Err(); err != nil { + return err + } + e.Info, err = xt.Unmarshal(e.Name, xt.TableFamily(fam), e.Rev, info) + return err +} diff --git a/vendor/github.com/google/nftables/expr/tproxy.go b/vendor/github.com/google/nftables/expr/tproxy.go new file mode 100644 index 0000000000..ea936f34a7 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/tproxy.go @@ -0,0 +1,68 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const ( + // NFTA_TPROXY_FAMILY defines attribute for a table family + NFTA_TPROXY_FAMILY = 0x01 + // NFTA_TPROXY_REG defines attribute for a register carrying redirection port value + NFTA_TPROXY_REG = 0x03 +) + +// TProxy defines struct with parameters for the transparent proxy +type TProxy struct { + Family byte + TableFamily byte + RegPort uint32 +} + +func (e *TProxy) marshal(fam byte) ([]byte, error) { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: NFTA_TPROXY_FAMILY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Family))}, + {Type: NFTA_TPROXY_REG, Data: binaryutil.BigEndian.PutUint32(e.RegPort)}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("tproxy\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *TProxy) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_TPROXY_FAMILY: + e.Family = ad.Uint8() + case NFTA_TPROXY_REG: + e.RegPort = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/verdict.go b/vendor/github.com/google/nftables/expr/verdict.go new file mode 100644 index 0000000000..421fa066d1 --- /dev/null +++ 
b/vendor/github.com/google/nftables/expr/verdict.go @@ -0,0 +1,128 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "bytes" + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// This code assembles the verdict structure, as expected by the +// nftables netlink API. +// For further information, consult: +// - netfilter.h (Linux kernel) +// - net/netfilter/nf_tables_api.c (Linux kernel) +// - src/expr/data_reg.c (linbnftnl) + +type Verdict struct { + Kind VerdictKind + Chain string +} + +type VerdictKind int64 + +// Verdicts, as per netfilter.h and netfilter/nf_tables.h. 
+const ( + VerdictReturn VerdictKind = iota - 5 + VerdictGoto + VerdictJump + VerdictBreak + VerdictContinue + VerdictDrop + VerdictAccept + VerdictStolen + VerdictQueue + VerdictRepeat + VerdictStop +) + +func (e *Verdict) marshal(fam byte) ([]byte, error) { + // A verdict is a tree of netlink attributes structured as follows: + // NFTA_LIST_ELEM | NLA_F_NESTED { + // NFTA_EXPR_NAME { "immediate\x00" } + // NFTA_EXPR_DATA | NLA_F_NESTED { + // NFTA_IMMEDIATE_DREG { NFT_REG_VERDICT } + // NFTA_IMMEDIATE_DATA | NLA_F_NESTED { + // the verdict code + // } + // } + // } + + attrs := []netlink.Attribute{ + {Type: unix.NFTA_VERDICT_CODE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Kind))}, + } + if e.Chain != "" { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_VERDICT_CHAIN, Data: []byte(e.Chain + "\x00")}) + } + codeData, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + immData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_DATA_VERDICT, Data: codeData}, + }) + if err != nil { + return nil, err + } + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_IMMEDIATE_DREG, Data: binaryutil.BigEndian.PutUint32(unix.NFT_REG_VERDICT)}, + {Type: unix.NLA_F_NESTED | unix.NFTA_IMMEDIATE_DATA, Data: immData}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Verdict) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_IMMEDIATE_DATA: + nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) + if err != nil { + return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err) + } + for nestedAD.Next() { + 
switch nestedAD.Type() { + case unix.NFTA_DATA_VERDICT: + e.Kind = VerdictKind(int32(binaryutil.BigEndian.Uint32(nestedAD.Bytes()[4:8]))) + if len(nestedAD.Bytes()) > 12 { + e.Chain = string(bytes.Trim(nestedAD.Bytes()[12:], "\x00")) + } + } + } + if nestedAD.Err() != nil { + return fmt.Errorf("decoding immediate: %v", nestedAD.Err()) + } + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/flowtable.go b/vendor/github.com/google/nftables/flowtable.go new file mode 100644 index 0000000000..01df08eb03 --- /dev/null +++ b/vendor/github.com/google/nftables/flowtable.go @@ -0,0 +1,306 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package nftables + +import ( + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const ( + // not in ztypes_linux.go, added here + // https://cs.opensource.google/go/x/sys/+/c6bc011c:unix/ztypes_linux.go;l=1870-1892 + NFT_MSG_NEWFLOWTABLE = 0x16 + NFT_MSG_GETFLOWTABLE = 0x17 + NFT_MSG_DELFLOWTABLE = 0x18 +) + +const ( + // not in ztypes_linux.go, added here + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1634 + _ = iota + NFTA_FLOWTABLE_TABLE + NFTA_FLOWTABLE_NAME + NFTA_FLOWTABLE_HOOK + NFTA_FLOWTABLE_USE + NFTA_FLOWTABLE_HANDLE + NFTA_FLOWTABLE_PAD + NFTA_FLOWTABLE_FLAGS +) + +const ( + // not in ztypes_linux.go, added here + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1657 + _ = iota + NFTA_FLOWTABLE_HOOK_NUM + NFTA_FLOWTABLE_PRIORITY + NFTA_FLOWTABLE_DEVS +) + +const ( + // not in ztypes_linux.go, added here, used for flowtable device name specification + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1709 + NFTA_DEVICE_NAME = 1 +) + +type FlowtableFlags uint32 + +const ( + _ FlowtableFlags = iota + FlowtableFlagsHWOffload + FlowtableFlagsCounter + FlowtableFlagsMask = (FlowtableFlagsHWOffload | FlowtableFlagsCounter) +) + +type FlowtableHook uint32 + +func FlowtableHookRef(h FlowtableHook) *FlowtableHook { + return &h +} + +var ( + // Only ingress is supported + // https://github.com/torvalds/linux/blob/b72018ab8236c3ae427068adeb94bdd3f20454ec/net/netfilter/nf_tables_api.c#L7378-L7379 + FlowtableHookIngress *FlowtableHook = FlowtableHookRef(unix.NF_NETDEV_INGRESS) +) + +type FlowtablePriority int32 + +func FlowtablePriorityRef(p FlowtablePriority) *FlowtablePriority { + return &p +} + +var ( + // As per man page: + // The priority can be a 
signed integer or filter which stands for 0. Addition and subtraction can be used to set relative priority, e.g. filter + 5 equals to 5. + // https://git.netfilter.org/nftables/tree/doc/nft.txt?id=8c600a843b7c0c1cc275ecc0603bd1fc57773e98#n712 + FlowtablePriorityFilter *FlowtablePriority = FlowtablePriorityRef(0) +) + +type Flowtable struct { + Table *Table + Name string + Hooknum *FlowtableHook + Priority *FlowtablePriority + Devices []string + Use uint32 + // Bitmask flags, can be HW_OFFLOAD or COUNTER + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1621 + Flags FlowtableFlags + Handle uint64 +} + +func (cc *Conn) AddFlowtable(f *Flowtable) *Flowtable { + cc.mu.Lock() + defer cc.mu.Unlock() + + data := cc.marshalAttr([]netlink.Attribute{ + {Type: NFTA_FLOWTABLE_TABLE, Data: []byte(f.Table.Name)}, + {Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)}, + {Type: NFTA_FLOWTABLE_FLAGS, Data: binaryutil.BigEndian.PutUint32(uint32(f.Flags))}, + }) + + if f.Hooknum == nil { + f.Hooknum = FlowtableHookIngress + } + + if f.Priority == nil { + f.Priority = FlowtablePriorityFilter + } + + hookAttr := []netlink.Attribute{ + {Type: NFTA_FLOWTABLE_HOOK_NUM, Data: binaryutil.BigEndian.PutUint32(uint32(*f.Hooknum))}, + {Type: NFTA_FLOWTABLE_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(*f.Priority))}, + } + if len(f.Devices) > 0 { + devs := make([]netlink.Attribute, len(f.Devices)) + for i, d := range f.Devices { + devs[i] = netlink.Attribute{Type: NFTA_DEVICE_NAME, Data: []byte(d)} + } + hookAttr = append(hookAttr, netlink.Attribute{ + Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_DEVS, + Data: cc.marshalAttr(devs), + }) + } + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_HOOK, Data: cc.marshalAttr(hookAttr)}, + })...) 
+ + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(f.Table.Family), 0), data...), + }) + + return f +} + +func (cc *Conn) DelFlowtable(f *Flowtable) { + cc.mu.Lock() + defer cc.mu.Unlock() + + data := cc.marshalAttr([]netlink.Attribute{ + {Type: NFTA_FLOWTABLE_TABLE, Data: []byte(f.Table.Name)}, + {Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)}, + }) + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DELFLOWTABLE), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(f.Table.Family), 0), data...), + }) +} + +func (cc *Conn) ListFlowtables(t *Table) ([]*Flowtable, error) { + reply, err := cc.getFlowtables(t) + if err != nil { + return nil, err + } + + var fts []*Flowtable + for _, msg := range reply { + f, err := ftsFromMsg(msg) + if err != nil { + return nil, err + } + f.Table = t + fts = append(fts, f) + } + + return fts, nil +} + +func (cc *Conn) getFlowtables(t *Table) ([]netlink.Message, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + attrs := []netlink.Attribute{ + {Type: NFTA_FLOWTABLE_TABLE, Data: []byte(t.Name + "\x00")}, + } + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_GETFLOWTABLE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := 
receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("Receive: %v", err) + } + + return reply, nil +} + +func ftsFromMsg(msg netlink.Message) (*Flowtable, error) { + flowHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE) + if got, want := msg.Header.Type, flowHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + var ft Flowtable + for ad.Next() { + switch ad.Type() { + case NFTA_FLOWTABLE_NAME: + ft.Name = ad.String() + case NFTA_FLOWTABLE_USE: + ft.Use = ad.Uint32() + case NFTA_FLOWTABLE_HANDLE: + ft.Handle = ad.Uint64() + case NFTA_FLOWTABLE_FLAGS: + ft.Flags = FlowtableFlags(ad.Uint32()) + case NFTA_FLOWTABLE_HOOK: + ad.Do(func(b []byte) error { + ft.Hooknum, ft.Priority, ft.Devices, err = ftsHookFromMsg(b) + return err + }) + } + } + return &ft, nil +} + +func ftsHookFromMsg(b []byte) (*FlowtableHook, *FlowtablePriority, []string, error) { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, nil, nil, err + } + + ad.ByteOrder = binary.BigEndian + + var hooknum FlowtableHook + var prio FlowtablePriority + var devices []string + + for ad.Next() { + switch ad.Type() { + case NFTA_FLOWTABLE_HOOK_NUM: + hooknum = FlowtableHook(ad.Uint32()) + case NFTA_FLOWTABLE_PRIORITY: + prio = FlowtablePriority(ad.Uint32()) + case NFTA_FLOWTABLE_DEVS: + ad.Do(func(b []byte) error { + devices, err = devsFromMsg(b) + return err + }) + } + } + + return &hooknum, &prio, devices, nil +} + +func devsFromMsg(b []byte) ([]string, error) { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, err + } + + ad.ByteOrder = binary.BigEndian + + devs := make([]string, 0) + for ad.Next() { + switch ad.Type() { + case NFTA_DEVICE_NAME: + devs = append(devs, ad.String()) + } + } + + return devs, 
nil +} diff --git a/vendor/github.com/google/nftables/internal/parseexprfunc/parseexprfunc.go b/vendor/github.com/google/nftables/internal/parseexprfunc/parseexprfunc.go new file mode 100644 index 0000000000..523859d755 --- /dev/null +++ b/vendor/github.com/google/nftables/internal/parseexprfunc/parseexprfunc.go @@ -0,0 +1,10 @@ +package parseexprfunc + +import ( + "github.com/mdlayher/netlink" +) + +var ( + ParseExprBytesFunc func(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]interface{}, error) + ParseExprMsgFunc func(fam byte, b []byte) ([]interface{}, error) +) diff --git a/vendor/github.com/google/nftables/obj.go b/vendor/github.com/google/nftables/obj.go new file mode 100644 index 0000000000..08d43f4637 --- /dev/null +++ b/vendor/github.com/google/nftables/obj.go @@ -0,0 +1,224 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "encoding/binary" + "fmt" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +var objHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ) + +// Obj represents a netfilter stateful object. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects +type Obj interface { + table() *Table + family() TableFamily + unmarshal(*netlink.AttributeDecoder) error + marshal(data bool) ([]byte, error) +} + +// AddObject adds the specified Obj. Alias of AddObj. 
+func (cc *Conn) AddObject(o Obj) Obj { + return cc.AddObj(o) +} + +// AddObj adds the specified Obj. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects +func (cc *Conn) AddObj(o Obj) Obj { + cc.mu.Lock() + defer cc.mu.Unlock() + data, err := o.marshal(true) + if err != nil { + cc.setErr(err) + return nil + } + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(o.family()), 0), data...), + }) + return o +} + +// DeleteObject deletes the specified Obj +func (cc *Conn) DeleteObject(o Obj) { + cc.mu.Lock() + defer cc.mu.Unlock() + data, err := o.marshal(false) + if err != nil { + cc.setErr(err) + return + } + + data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(o.family()), 0), data...), + }) +} + +// GetObj is a legacy method that return all Obj that belongs +// to the same table as the given one +func (cc *Conn) GetObj(o Obj) ([]Obj, error) { + return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ) +} + +// GetObjReset is a legacy method that reset all Obj that belongs +// the same table as the given one +func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { + return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET) +} + +// GetObject gets the specified Object +func (cc *Conn) GetObject(o Obj) (Obj, error) { + objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ) + + if len(objs) == 0 { + return nil, err + } + + return objs[0], err +} + +// GetObjects get all the Obj that belongs to the given table +func (cc *Conn) 
GetObjects(t *Table) ([]Obj, error) { + return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ) +} + +// ResetObject reset the given Obj +func (cc *Conn) ResetObject(o Obj) (Obj, error) { + objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET) + + if len(objs) == 0 { + return nil, err + } + + return objs[0], err +} + +// ResetObjects reset all the Obj that belongs to the given table +func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) { + return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET) +} + +func objFromMsg(msg netlink.Message) (Obj, error) { + if got, want := msg.Header.Type, objHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + var ( + table *Table + name string + objectType uint32 + ) + const NFT_OBJECT_COUNTER = 1 // TODO: get into x/sys/unix + for ad.Next() { + switch ad.Type() { + case unix.NFTA_OBJ_TABLE: + table = &Table{Name: ad.String(), Family: TableFamily(msg.Data[0])} + case unix.NFTA_OBJ_NAME: + name = ad.String() + case unix.NFTA_OBJ_TYPE: + objectType = ad.Uint32() + case unix.NFTA_OBJ_DATA: + switch objectType { + case NFT_OBJECT_COUNTER: + o := CounterObj{ + Table: table, + Name: name, + } + + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + return o.unmarshal(ad) + }) + return &o, ad.Err() + } + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return nil, fmt.Errorf("malformed stateful object") +} + +func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + var data []byte + var flags netlink.HeaderFlags + + if o != nil { + data, err = o.marshal(false) + } else { + flags = netlink.Dump + data, err = 
netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + }) + } + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType), + Flags: netlink.Request | netlink.Acknowledge | flags, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("Receive: %v", err) + } + var objs []Obj + for _, msg := range reply { + o, err := objFromMsg(msg) + if err != nil { + return nil, err + } + objs = append(objs, o) + } + + return objs, nil +} diff --git a/vendor/github.com/google/nftables/rule.go b/vendor/github.com/google/nftables/rule.go new file mode 100644 index 0000000000..f004e45105 --- /dev/null +++ b/vendor/github.com/google/nftables/rule.go @@ -0,0 +1,270 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package nftables + +import ( + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/google/nftables/internal/parseexprfunc" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +var ruleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE) + +type ruleOperation uint32 + +// Possible PayloadOperationType values. +const ( + operationAdd ruleOperation = iota + operationInsert + operationReplace +) + +// A Rule does something with a packet. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management +type Rule struct { + Table *Table + Chain *Chain + Position uint64 + Handle uint64 + // The list of possible flags are specified by nftnl_rule_attr, see + // https://git.netfilter.org/libnftnl/tree/include/libnftnl/rule.h#n21 + // Current nftables go implementation supports only + // NFTNL_RULE_POSITION flag for setting rule at position 0 + Flags uint32 + Exprs []expr.Any + UserData []byte +} + +// GetRule returns the rules in the specified table and chain. +// +// Deprecated: use GetRules instead. +func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { + return cc.GetRules(t, c) +} + +// GetRules returns the rules in the specified table and chain. 
+func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("Receive: %v", err) + } + var rules []*Rule + for _, msg := range reply { + r, err := ruleFromMsg(t.Family, msg) + if err != nil { + return nil, err + } + // Carry over all Table attributes (including Family), as the Table + // object which ruleFromMsg creates only contains the name. 
+ r.Table = t + rules = append(rules, r) + } + + return rules, nil +} + +// AddRule adds the specified Rule +func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { + cc.mu.Lock() + defer cc.mu.Unlock() + exprAttrs := make([]netlink.Attribute, len(r.Exprs)) + for idx, expr := range r.Exprs { + exprAttrs[idx] = netlink.Attribute{ + Type: unix.NLA_F_NESTED | unix.NFTA_LIST_ELEM, + Data: cc.marshalExpr(byte(r.Table.Family), expr), + } + } + + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")}, + }) + + if r.Handle != 0 { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(r.Handle)}, + })...) + } + + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_EXPRESSIONS, Data: cc.marshalAttr(exprAttrs)}, + })...) + + if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil { + cc.setErr(err) + } else if compatPolicy != nil { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_COMPAT_PROTO, Data: binaryutil.BigEndian.PutUint32(compatPolicy.Proto)}, + {Type: unix.NFTA_RULE_COMPAT_FLAGS, Data: binaryutil.BigEndian.PutUint32(compatPolicy.Flag & nft_RULE_COMPAT_F_MASK)}, + })}, + })...) + } + + msgData := []byte{} + + msgData = append(msgData, data...) + var flags netlink.HeaderFlags + if r.UserData != nil { + msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_USERDATA, Data: r.UserData}, + })...) 
+ } + + switch op { + case operationAdd: + flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND + case operationInsert: + flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO + case operationReplace: + flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE + } + + if r.Position != 0 || (r.Flags&(1< 32/SetConcatTypeBits { + return SetDatatype{}, ErrTooManyTypes + } + + var magic, bytes uint32 + names := make([]string, len(types)) + for i, t := range types { + bytes += t.Bytes + // concatenated types pad the length to multiples of the register size (4 bytes) + // see https://git.netfilter.org/nftables/tree/src/datatype.c?id=488356b895024d0944b20feb1f930558726e0877#n1162 + if t.Bytes%4 != 0 { + bytes += 4 - (t.Bytes % 4) + } + names[i] = t.Name + + magic <<= SetConcatTypeBits + magic |= t.nftMagic & SetConcatTypeMask + } + return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil +} + +// ConcatSetTypeElements uses the ConcatSetType name to calculate and return +// a list of base types which were used to construct the concatenated type +func ConcatSetTypeElements(t SetDatatype) []SetDatatype { + names := strings.Split(t.Name, " . ") + types := make([]SetDatatype, len(names)) + for i, n := range names { + types[i] = nftDatatypes[n] + } + return types +} + +// Set represents an nftables set. Anonymous sets are only valid within the +// context of a single batch. 
+type Set struct { + Table *Table + ID uint32 + Name string + Anonymous bool + Constant bool + Interval bool + IsMap bool + HasTimeout bool + Counter bool + // Can be updated per evaluation path, per `nft list ruleset` + // indicates that set contains "flags dynamic" + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n298 + Dynamic bool + // Indicates that the set contains a concatenation + // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n306 + Concatenation bool + Timeout time.Duration + KeyType SetDatatype + DataType SetDatatype +} + +// SetElement represents a data point within a set. +type SetElement struct { + Key []byte + Val []byte + // Field used for definition of ending interval value in concatenated types + // https://git.netfilter.org/libnftnl/tree/include/set_elem.h?id=e2514c0eff4da7e8e0aabd410f7b7d0b7564c880#n11 + KeyEnd []byte + IntervalEnd bool + // To support vmap, a caller must be able to pass Verdict type of data. + // If IsMap is true and VerdictData is not nil, then Val of SetElement will be ignored + // and VerdictData will be wrapped into Attribute data. 
+ VerdictData *expr.Verdict + // To support aging of set elements + Timeout time.Duration +} + +func (s *SetElement) decode() func(b []byte) error { + return func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return fmt.Errorf("failed to create nested attribute decoder: %v", err) + } + ad.ByteOrder = binary.BigEndian + + for ad.Next() { + switch ad.Type() { + case unix.NFTA_SET_ELEM_KEY: + s.Key, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } + case NFTA_SET_ELEM_KEY_END: + s.KeyEnd, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } + case unix.NFTA_SET_ELEM_DATA: + s.Val, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } + case unix.NFTA_SET_ELEM_FLAGS: + flags := ad.Uint32() + s.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0 + case unix.NFTA_SET_ELEM_TIMEOUT: + s.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64())) + } + } + return ad.Err() + } +} + +func decodeElement(d []byte) ([]byte, error) { + ad, err := netlink.NewAttributeDecoder(d) + if err != nil { + return nil, fmt.Errorf("failed to create nested attribute decoder: %v", err) + } + ad.ByteOrder = binary.BigEndian + var b []byte + for ad.Next() { + switch ad.Type() { + case unix.NFTA_SET_ELEM_KEY: + fallthrough + case unix.NFTA_SET_ELEM_DATA: + b = ad.Bytes() + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return b, nil +} + +// SetAddElements applies data points to an nftables set. 
+func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { + cc.mu.Lock() + defer cc.mu.Unlock() + if s.Anonymous { + return errors.New("anonymous sets cannot be updated") + } + + elements, err := s.makeElemList(vals, s.ID) + if err != nil { + return err + } + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), + }) + + return nil +} + +func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) { + var elements []netlink.Attribute + + for i, v := range vals { + item := make([]netlink.Attribute, 0) + var flags uint32 + if v.IntervalEnd { + flags |= unix.NFT_SET_ELEM_INTERVAL_END + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)}) + } + + encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}}) + if err != nil { + return nil, fmt.Errorf("marshal key %d: %v", i, err) + } + + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) + if len(v.KeyEnd) > 0 { + encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) + if err != nil { + return nil, fmt.Errorf("marshal key end %d: %v", i, err) + } + item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) + } + if s.HasTimeout && v.Timeout != 0 { + // Set has Timeout flag set, which means an individual element can specify its own timeout. 
+ item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(v.Timeout.Milliseconds()))}) + } + // The following switch statement deal with 3 different types of elements. + // 1. v is an element of vmap + // 2. v is an element of a regular map + // 3. v is an element of a regular set (default) + switch { + case v.VerdictData != nil: + // Since VerdictData is not nil, v is vmap element, need to add to the attributes + encodedVal := []byte{} + encodedKind, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, + }) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + encodedVal = append(encodedVal, encodedKind...) + if len(v.VerdictData.Chain) != 0 { + encodedChain, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, + }) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + encodedVal = append(encodedVal, encodedChain...) 
+ } + encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) + case len(v.Val) > 0: + // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes + encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) + default: + // If niether of previous cases matche, it means 'e' is an element of a regular Set, no need to add to the attributes + } + + encodedItem, err := netlink.MarshalAttributes(item) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem}) + } + + encodedElem, err := netlink.MarshalAttributes(elements) + if err != nil { + return nil, fmt.Errorf("marshal elements: %v", err) + } + + return []netlink.Attribute{ + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + {Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)}, + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, + }, nil +} + +// AddSet adds the specified Set. +func (cc *Conn) AddSet(s *Set, vals []SetElement) error { + cc.mu.Lock() + defer cc.mu.Unlock() + // Based on nft implementation & linux source. 
+ // Link: https://github.com/torvalds/linux/blob/49a57857aeea06ca831043acbb0fa5e0f50602fd/net/netfilter/nf_tables_api.c#L3395 + // Another reference: https://git.netfilter.org/nftables/tree/src + + if s.Anonymous && !s.Constant { + return errors.New("anonymous structs must be constant") + } + + if s.ID == 0 { + allocSetID++ + s.ID = allocSetID + if s.Anonymous { + s.Name = "__set%d" + if s.IsMap { + s.Name = "__map%d" + } + } + } + + var flags uint32 + if s.Anonymous { + flags |= unix.NFT_SET_ANONYMOUS + } + if s.Constant { + flags |= unix.NFT_SET_CONSTANT + } + if s.Interval { + flags |= unix.NFT_SET_INTERVAL + } + if s.IsMap { + flags |= unix.NFT_SET_MAP + } + if s.HasTimeout { + flags |= unix.NFT_SET_TIMEOUT + } + if s.Dynamic { + flags |= unix.NFT_SET_EVAL + } + if s.Concatenation { + flags |= NFT_SET_CONCAT + } + tableInfo := []netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + {Type: unix.NFTA_SET_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, + {Type: unix.NFTA_SET_KEY_TYPE, Data: binaryutil.BigEndian.PutUint32(s.KeyType.nftMagic)}, + {Type: unix.NFTA_SET_KEY_LEN, Data: binaryutil.BigEndian.PutUint32(s.KeyType.Bytes)}, + {Type: unix.NFTA_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)}, + } + if s.IsMap { + // Check if it is vmap case + if s.DataType.nftMagic == 1 { + // For Verdict data type, the expected magic is 0xfffff0 + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(unix.NFT_DATA_VERDICT))}, + netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) + } else { + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(s.DataType.nftMagic)}, + netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) + } + } + if 
s.HasTimeout && s.Timeout != 0 { + // If Set's global timeout is specified, add it to set's attributes + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(s.Timeout.Milliseconds()))}) + } + if s.Constant { + // nft cli tool adds the number of elements to set/map's descriptor + // It make sense to do only if a set or map are constant, otherwise skip NFTA_SET_DESC attribute + numberOfElements, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))}, + }) + if err != nil { + return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err) + } + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) + } + if s.Concatenation { + // Length of concatenated types is a must, otherwise segfaults when executing nft list ruleset + var concatDefinition []byte + elements := ConcatSetTypeElements(s.KeyType) + for i, v := range elements { + // Marshal base type size value + valData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, + }) + if err != nil { + return fmt.Errorf("fail to marshal element key size %d: %v", i, err) + } + // Marshal base type size description + descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_DESC_SIZE, Data: valData}, + }) + + concatDefinition = append(concatDefinition, descSize...) 
+ } + // Marshal all base type descriptions into concatenation size description + concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) + if err != nil { + return fmt.Errorf("fail to marshal concat definition %v", err) + } + // Marshal concat size description as set description + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: concatBytes}) + } + if s.Anonymous || s.Constant || s.Interval { + tableInfo = append(tableInfo, + // Semantically useless - kept for binary compatability with nft + netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")}) + } else if !s.IsMap { + // Per https://git.netfilter.org/nftables/tree/src/mnl.c?id=187c6d01d35722618c2711bbc49262c286472c8f#n1165 + tableInfo = append(tableInfo, + netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x01\x00\x00\x00")}) + } + if s.Counter { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_LIST_ELEM, Data: []byte("counter\x00")}, + {Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}}, + }) + if err != nil { + return err + } + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data}) + } + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(tableInfo)...), + }) + + // Set the values of the set if initial values were provided. 
+ if len(vals) > 0 { + hdrType := unix.NFT_MSG_NEWSETELEM + elements, err := s.makeElemList(vals, s.ID) + if err != nil { + return err + } + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), + }) + } + + return nil +} + +// DelSet deletes a specific set, along with all elements it contains. +func (cc *Conn) DelSet(s *Set) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), data...), + }) +} + +// SetDeleteElements deletes data points from an nftables set. +func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { + cc.mu.Lock() + defer cc.mu.Unlock() + if s.Anonymous { + return errors.New("anonymous sets cannot be updated") + } + + elements, err := s.makeElemList(vals, s.ID) + if err != nil { + return err + } + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), + }) + + return nil +} + +// FlushSet deletes all data points from an nftables set. 
+func (cc *Conn) FlushSet(s *Set) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), data...), + }) +} + +var setHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET) + +func setsFromMsg(msg netlink.Message) (*Set, error) { + if got, want := msg.Header.Type, setHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + var set Set + for ad.Next() { + switch ad.Type() { + case unix.NFTA_SET_NAME: + set.Name = ad.String() + case unix.NFTA_SET_ID: + set.ID = binary.BigEndian.Uint32(ad.Bytes()) + case unix.NFTA_SET_TIMEOUT: + set.Timeout = time.Duration(time.Millisecond * time.Duration(binary.BigEndian.Uint64(ad.Bytes()))) + set.HasTimeout = true + case unix.NFTA_SET_FLAGS: + flags := ad.Uint32() + set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0 + set.Anonymous = (flags & unix.NFT_SET_ANONYMOUS) != 0 + set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0 + set.IsMap = (flags & unix.NFT_SET_MAP) != 0 + set.HasTimeout = (flags & unix.NFT_SET_TIMEOUT) != 0 + set.Concatenation = (flags & NFT_SET_CONCAT) != 0 + case unix.NFTA_SET_KEY_TYPE: + nftMagic := ad.Uint32() + if invalidMagic, ok := validateKeyType(nftMagic); !ok { + return nil, fmt.Errorf("could not determine key type %+v", invalidMagic) + } + set.KeyType.nftMagic = nftMagic + for _, dt := range nftDatatypes { + // If this is a non-concatenated type, we can assign 
the descriptor. + if nftMagic == dt.nftMagic { + set.KeyType = dt + break + } + } + case unix.NFTA_SET_DATA_TYPE: + nftMagic := ad.Uint32() + // Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1 + if nftMagic == 0xffffff00 { + set.KeyType = TypeVerdict + break + } + for _, dt := range nftDatatypes { + if nftMagic == dt.nftMagic { + set.DataType = dt + break + } + } + if set.DataType.nftMagic == 0 { + return nil, fmt.Errorf("could not determine data type %x", nftMagic) + } + } + } + return &set, nil +} + +func validateKeyType(bits uint32) ([]uint32, bool) { + var unpackTypes []uint32 + var invalidTypes []uint32 + found := false + valid := true + for bits != 0 { + unpackTypes = append(unpackTypes, bits&SetConcatTypeMask) + bits = bits >> SetConcatTypeBits + } + for _, t := range unpackTypes { + for _, dt := range nftDatatypes { + if t == dt.nftMagic { + found = true + } + } + if !found { + invalidTypes = append(invalidTypes, t) + valid = false + } + found = false + } + return invalidTypes, valid +} + +var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) + +func elementsFromMsg(msg netlink.Message) ([]SetElement, error) { + if got, want := msg.Header.Type, elemHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + var elements []SetElement + for ad.Next() { + b := ad.Bytes() + if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + for ad.Next() { + var elem SetElement + switch ad.Type() { + case unix.NFTA_LIST_ELEM: + ad.Do(elem.decode()) + } + elements = append(elements, elem) + } + } + } + return elements, nil +} + +// GetSets returns the sets in the specified table. 
+func (cc *Conn) GetSets(t *Table) ([]*Set, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSET), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("Receive: %v", err) + } + var sets []*Set + for _, msg := range reply { + s, err := setsFromMsg(msg) + if err != nil { + return nil, err + } + s.Table = &Table{Name: t.Name, Use: t.Use, Flags: t.Flags, Family: t.Family} + sets = append(sets, s) + } + + return sets, nil +} + +// GetSetByName returns the set in the specified table if matching name is found. 
+func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSET), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %w", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("Receive: %w", err) + } + + if len(reply) != 1 { + return nil, fmt.Errorf("Receive: expected to receive 1 message but got %d", len(reply)) + } + rs, err := setsFromMsg(reply[0]) + if err != nil { + return nil, err + } + rs.Table = &Table{Name: t.Name, Use: t.Use, Flags: t.Flags, Family: t.Family} + + return rs, nil +} + +// GetSetElements returns the elements in the specified set. 
+func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSETELEM), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("Receive: %v", err) + } + var elems []SetElement + for _, msg := range reply { + s, err := elementsFromMsg(msg) + if err != nil { + return nil, err + } + elems = append(elems, s...) + } + + return elems, nil +} diff --git a/vendor/github.com/google/nftables/table.go b/vendor/github.com/google/nftables/table.go new file mode 100644 index 0000000000..7ea4aa9ed3 --- /dev/null +++ b/vendor/github.com/google/nftables/table.go @@ -0,0 +1,167 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package nftables + +import ( + "fmt" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +var tableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) + +// TableFamily specifies the address family for this table. +type TableFamily byte + +// Possible TableFamily values. +const ( + TableFamilyUnspecified TableFamily = unix.NFPROTO_UNSPEC + TableFamilyINet TableFamily = unix.NFPROTO_INET + TableFamilyIPv4 TableFamily = unix.NFPROTO_IPV4 + TableFamilyIPv6 TableFamily = unix.NFPROTO_IPV6 + TableFamilyARP TableFamily = unix.NFPROTO_ARP + TableFamilyNetdev TableFamily = unix.NFPROTO_NETDEV + TableFamilyBridge TableFamily = unix.NFPROTO_BRIDGE +) + +// A Table contains Chains. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables +type Table struct { + Name string // NFTA_TABLE_NAME + Use uint32 // NFTA_TABLE_USE (Number of chains in table) + Flags uint32 // NFTA_TABLE_FLAGS + Family TableFamily +} + +// DelTable deletes a specific table, along with all chains/rules it contains. +func (cc *Conn) DelTable(t *Table) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + }) +} + +// AddTable adds the specified Table. 
See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables +func (cc *Conn) AddTable(t *Table) *Table { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + }) + return t +} + +// FlushTable removes all rules in all chains within the specified Table. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables#Flushing_tables +func (cc *Conn) FlushTable(t *Table) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + }) +} + +// ListTables returns currently configured tables in the kernel +func (cc *Conn) ListTables() ([]*Table, error) { + return cc.ListTablesOfFamily(TableFamilyUnspecified) +} + +// ListTablesOfFamily returns currently configured tables for the specified table family +// in the kernel. It lists all tables if family is TableFamilyUnspecified. 
+func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + msg := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE), + Flags: netlink.Request | netlink.Dump, + }, + Data: extraHeader(uint8(family), 0), + } + + response, err := conn.Execute(msg) + if err != nil { + return nil, err + } + + var tables []*Table + for _, m := range response { + t, err := tableFromMsg(m) + if err != nil { + return nil, err + } + + tables = append(tables, t) + } + + return tables, nil +} + +func tableFromMsg(msg netlink.Message) (*Table, error) { + if got, want := msg.Header.Type, tableHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + + var t Table + t.Family = TableFamily(msg.Data[0]) + + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + + for ad.Next() { + switch ad.Type() { + case unix.NFTA_TABLE_NAME: + t.Name = ad.String() + case unix.NFTA_TABLE_USE: + t.Use = ad.Uint32() + case unix.NFTA_TABLE_FLAGS: + t.Flags = ad.Uint32() + } + } + + return &t, nil +} diff --git a/vendor/github.com/google/nftables/util.go b/vendor/github.com/google/nftables/util.go new file mode 100644 index 0000000000..de8880720e --- /dev/null +++ b/vendor/github.com/google/nftables/util.go @@ -0,0 +1,27 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "github.com/google/nftables/binaryutil" + "golang.org/x/sys/unix" +) + +func extraHeader(family uint8, resID uint16) []byte { + return append([]byte{ + family, + unix.NFNETLINK_V0, + }, binaryutil.BigEndian.PutUint16(resID)...) +} diff --git a/vendor/github.com/google/nftables/xt/info.go b/vendor/github.com/google/nftables/xt/info.go new file mode 100644 index 0000000000..0cf9ab95e0 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/info.go @@ -0,0 +1,94 @@ +package xt + +import ( + "golang.org/x/sys/unix" +) + +// TableFamily specifies the address family of the table Match or Target Info +// data is contained in. On purpose, we don't import the expr package here in +// order to keep the option open to import this package instead into expr. +type TableFamily byte + +// InfoAny is a (un)marshaling implemented by any info type. +type InfoAny interface { + marshal(fam TableFamily, rev uint32) ([]byte, error) + unmarshal(fam TableFamily, rev uint32, data []byte) error +} + +// Marshal a Match or Target Info type into its binary representation. +func Marshal(fam TableFamily, rev uint32, info InfoAny) ([]byte, error) { + return info.marshal(fam, rev) +} + +// Unmarshal Info binary payload into its corresponding dedicated type as +// indicated by the name argument. In several cases, unmarshalling depends on +// the specific table family the Target or Match expression with the info +// payload belongs to, as well as the specific info structure revision. 
+func Unmarshal(name string, fam TableFamily, rev uint32, data []byte) (InfoAny, error) { + var i InfoAny + switch name { + case "addrtype": + switch rev { + case 0: + i = &AddrType{} + case 1: + i = &AddrTypeV1{} + } + case "conntrack": + switch rev { + case 1: + i = &ConntrackMtinfo1{} + case 2: + i = &ConntrackMtinfo2{} + case 3: + i = &ConntrackMtinfo3{} + } + case "tcp": + i = &Tcp{} + case "udp": + i = &Udp{} + case "SNAT": + if fam == unix.NFPROTO_IPV4 { + i = &NatIPv4MultiRangeCompat{} + } + case "DNAT": + switch fam { + case unix.NFPROTO_IPV4: + if rev == 0 { + i = &NatIPv4MultiRangeCompat{} + break + } + fallthrough + case unix.NFPROTO_IPV6: + switch rev { + case 1: + i = &NatRange{} + case 2: + i = &NatRange2{} + } + } + case "MASQUERADE": + switch fam { + case unix.NFPROTO_IPV4: + i = &NatIPv4MultiRangeCompat{} + } + case "REDIRECT": + switch fam { + case unix.NFPROTO_IPV4: + if rev == 0 { + i = &NatIPv4MultiRangeCompat{} + break + } + fallthrough + case unix.NFPROTO_IPV6: + i = &NatRange{} + } + } + if i == nil { + i = &Unknown{} + } + if err := i.unmarshal(fam, rev, data); err != nil { + return nil, err + } + return i, nil +} diff --git a/vendor/github.com/google/nftables/xt/match_addrtype.go b/vendor/github.com/google/nftables/xt/match_addrtype.go new file mode 100644 index 0000000000..3e21057a19 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/match_addrtype.go @@ -0,0 +1,89 @@ +package xt + +import ( + "github.com/google/nftables/alignedbuff" +) + +// Rev. 
0, see https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_addrtype.h#L38 +type AddrType struct { + Source uint16 + Dest uint16 + InvertSource bool + InvertDest bool +} + +type AddrTypeFlags uint32 + +const ( + AddrTypeUnspec AddrTypeFlags = 1 << iota + AddrTypeUnicast + AddrTypeLocal + AddrTypeBroadcast + AddrTypeAnycast + AddrTypeMulticast + AddrTypeBlackhole + AddrTypeUnreachable + AddrTypeProhibit + AddrTypeThrow + AddrTypeNat + AddrTypeXresolve +) + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_addrtype.h#L31 +type AddrTypeV1 struct { + Source uint16 + Dest uint16 + Flags AddrTypeFlags +} + +func (x *AddrType) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + ab.PutUint16(x.Source) + ab.PutUint16(x.Dest) + putBool32(&ab, x.InvertSource) + putBool32(&ab, x.InvertDest) + return ab.Data(), nil +} + +func (x *AddrType) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if x.Source, err = ab.Uint16(); err != nil { + return nil + } + if x.Dest, err = ab.Uint16(); err != nil { + return nil + } + if x.InvertSource, err = bool32(&ab); err != nil { + return nil + } + if x.InvertDest, err = bool32(&ab); err != nil { + return nil + } + return nil +} + +func (x *AddrTypeV1) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + ab.PutUint16(x.Source) + ab.PutUint16(x.Dest) + ab.PutUint32(uint32(x.Flags)) + return ab.Data(), nil +} + +func (x *AddrTypeV1) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if x.Source, err = ab.Uint16(); err != nil { + return nil + } + if x.Dest, err = ab.Uint16(); err != nil { + return nil + } + var flags uint32 + if flags, err = ab.Uint32(); err != nil { + return nil + } + x.Flags = AddrTypeFlags(flags) + return nil +} diff --git 
a/vendor/github.com/google/nftables/xt/match_conntrack.go b/vendor/github.com/google/nftables/xt/match_conntrack.go new file mode 100644 index 0000000000..69c51bd80c --- /dev/null +++ b/vendor/github.com/google/nftables/xt/match_conntrack.go @@ -0,0 +1,260 @@ +package xt + +import ( + "net" + + "github.com/google/nftables/alignedbuff" +) + +type ConntrackFlags uint16 + +const ( + ConntrackState ConntrackFlags = 1 << iota + ConntrackProto + ConntrackOrigSrc + ConntrackOrigDst + ConntrackReplSrc + ConntrackReplDst + ConntrackStatus + ConntrackExpires + ConntrackOrigSrcPort + ConntrackOrigDstPort + ConntrackReplSrcPort + ConntrackReplDstPrt + ConntrackDirection + ConntrackStateAlias +) + +type ConntrackMtinfoBase struct { + OrigSrcAddr net.IP + OrigSrcMask net.IPMask + OrigDstAddr net.IP + OrigDstMask net.IPMask + ReplSrcAddr net.IP + ReplSrcMask net.IPMask + ReplDstAddr net.IP + ReplDstMask net.IPMask + ExpiresMin uint32 + ExpiresMax uint32 + L4Proto uint16 + OrigSrcPort uint16 + OrigDstPort uint16 + ReplSrcPort uint16 + ReplDstPort uint16 + MatchFlags uint16 + InvertFlags uint16 +} + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L38 +type ConntrackMtinfo1 struct { + ConntrackMtinfoBase + StateMask uint8 + StatusMask uint8 +} + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L51 +type ConntrackMtinfo2 struct { + ConntrackMtinfoBase + StateMask uint16 + StatusMask uint16 +} + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L64 +type ConntrackMtinfo3 struct { + ConntrackMtinfo2 + OrigSrcPortHigh uint16 + OrigDstPortHigh uint16 + ReplSrcPortHigh uint16 + ReplDstPortHigh uint16 +} + +func (x *ConntrackMtinfoBase) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + if err := putIPv46(ab, fam, x.OrigSrcAddr); err != nil { + return err + } + if err := putIPv46Mask(ab, fam, x.OrigSrcMask); err 
!= nil { + return err + } + if err := putIPv46(ab, fam, x.OrigDstAddr); err != nil { + return err + } + if err := putIPv46Mask(ab, fam, x.OrigDstMask); err != nil { + return err + } + if err := putIPv46(ab, fam, x.ReplSrcAddr); err != nil { + return err + } + if err := putIPv46Mask(ab, fam, x.ReplSrcMask); err != nil { + return err + } + if err := putIPv46(ab, fam, x.ReplDstAddr); err != nil { + return err + } + if err := putIPv46Mask(ab, fam, x.ReplDstMask); err != nil { + return err + } + ab.PutUint32(x.ExpiresMin) + ab.PutUint32(x.ExpiresMax) + ab.PutUint16(x.L4Proto) + ab.PutUint16(x.OrigSrcPort) + ab.PutUint16(x.OrigDstPort) + ab.PutUint16(x.ReplSrcPort) + ab.PutUint16(x.ReplDstPort) + ab.PutUint16(x.MatchFlags) + ab.PutUint16(x.InvertFlags) + return nil +} + +func (x *ConntrackMtinfoBase) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + var err error + if x.OrigSrcAddr, err = iPv46(ab, fam); err != nil { + return err + } + if x.OrigSrcMask, err = iPv46Mask(ab, fam); err != nil { + return err + } + if x.OrigDstAddr, err = iPv46(ab, fam); err != nil { + return err + } + if x.OrigDstMask, err = iPv46Mask(ab, fam); err != nil { + return err + } + if x.ReplSrcAddr, err = iPv46(ab, fam); err != nil { + return err + } + if x.ReplSrcMask, err = iPv46Mask(ab, fam); err != nil { + return err + } + if x.ReplDstAddr, err = iPv46(ab, fam); err != nil { + return err + } + if x.ReplDstMask, err = iPv46Mask(ab, fam); err != nil { + return err + } + if x.ExpiresMin, err = ab.Uint32(); err != nil { + return err + } + if x.ExpiresMax, err = ab.Uint32(); err != nil { + return err + } + if x.L4Proto, err = ab.Uint16(); err != nil { + return err + } + if x.OrigSrcPort, err = ab.Uint16(); err != nil { + return err + } + if x.OrigDstPort, err = ab.Uint16(); err != nil { + return err + } + if x.ReplSrcPort, err = ab.Uint16(); err != nil { + return err + } + if x.ReplDstPort, err = ab.Uint16(); err != nil { + return err + } + if x.MatchFlags, err = 
ab.Uint16(); err != nil { + return err + } + if x.InvertFlags, err = ab.Uint16(); err != nil { + return err + } + return nil +} + +func (x *ConntrackMtinfo1) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.ConntrackMtinfoBase.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + ab.PutUint8(x.StateMask) + ab.PutUint8(x.StatusMask) + return ab.Data(), nil +} + +func (x *ConntrackMtinfo1) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if err = x.ConntrackMtinfoBase.unmarshalAB(fam, rev, &ab); err != nil { + return err + } + if x.StateMask, err = ab.Uint8(); err != nil { + return err + } + if x.StatusMask, err = ab.Uint8(); err != nil { + return err + } + return nil +} + +func (x *ConntrackMtinfo2) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + if err := x.ConntrackMtinfoBase.marshalAB(fam, rev, ab); err != nil { + return err + } + ab.PutUint16(x.StateMask) + ab.PutUint16(x.StatusMask) + return nil +} + +func (x *ConntrackMtinfo2) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + return ab.Data(), nil +} + +func (x *ConntrackMtinfo2) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + var err error + if err = x.ConntrackMtinfoBase.unmarshalAB(fam, rev, ab); err != nil { + return err + } + if x.StateMask, err = ab.Uint16(); err != nil { + return err + } + if x.StatusMask, err = ab.Uint16(); err != nil { + return err + } + return nil +} + +func (x *ConntrackMtinfo2) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if err = x.unmarshalAB(fam, rev, &ab); err != nil { + return err + } + return nil +} + +func (x *ConntrackMtinfo3) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := 
x.ConntrackMtinfo2.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + ab.PutUint16(x.OrigSrcPortHigh) + ab.PutUint16(x.OrigDstPortHigh) + ab.PutUint16(x.ReplSrcPortHigh) + ab.PutUint16(x.ReplDstPortHigh) + return ab.Data(), nil +} + +func (x *ConntrackMtinfo3) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if err = x.ConntrackMtinfo2.unmarshalAB(fam, rev, &ab); err != nil { + return err + } + if x.OrigSrcPortHigh, err = ab.Uint16(); err != nil { + return err + } + if x.OrigDstPortHigh, err = ab.Uint16(); err != nil { + return err + } + if x.ReplSrcPortHigh, err = ab.Uint16(); err != nil { + return err + } + if x.ReplDstPortHigh, err = ab.Uint16(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/google/nftables/xt/match_tcp.go b/vendor/github.com/google/nftables/xt/match_tcp.go new file mode 100644 index 0000000000..d991f12767 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/match_tcp.go @@ -0,0 +1,74 @@ +package xt + +import ( + "github.com/google/nftables/alignedbuff" +) + +// Tcp is the Match.Info payload for the tcp xtables extension +// (https://wiki.nftables.org/wiki-nftables/index.php/Supported_features_compared_to_xtables#tcp). 
+// +// See +// https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_tcpudp.h#L8 +type Tcp struct { + SrcPorts [2]uint16 // min, max source port range + DstPorts [2]uint16 // min, max destination port range + Option uint8 // TCP option if non-zero + FlagsMask uint8 // TCP flags mask + FlagsCmp uint8 // TCP flags compare + InvFlags TcpInvFlagset // Inverse flags +} + +type TcpInvFlagset uint8 + +const ( + TcpInvSrcPorts TcpInvFlagset = 1 << iota + TcpInvDestPorts + TcpInvFlags + TcpInvOption + TcpInvMask TcpInvFlagset = (1 << iota) - 1 +) + +func (x *Tcp) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + ab.PutUint16(x.SrcPorts[0]) + ab.PutUint16(x.SrcPorts[1]) + ab.PutUint16(x.DstPorts[0]) + ab.PutUint16(x.DstPorts[1]) + ab.PutUint8(x.Option) + ab.PutUint8(x.FlagsMask) + ab.PutUint8(x.FlagsCmp) + ab.PutUint8(byte(x.InvFlags)) + return ab.Data(), nil +} + +func (x *Tcp) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if x.SrcPorts[0], err = ab.Uint16(); err != nil { + return err + } + if x.SrcPorts[1], err = ab.Uint16(); err != nil { + return err + } + if x.DstPorts[0], err = ab.Uint16(); err != nil { + return err + } + if x.DstPorts[1], err = ab.Uint16(); err != nil { + return err + } + if x.Option, err = ab.Uint8(); err != nil { + return err + } + if x.FlagsMask, err = ab.Uint8(); err != nil { + return err + } + if x.FlagsCmp, err = ab.Uint8(); err != nil { + return err + } + var invFlags uint8 + if invFlags, err = ab.Uint8(); err != nil { + return err + } + x.InvFlags = TcpInvFlagset(invFlags) + return nil +} diff --git a/vendor/github.com/google/nftables/xt/match_udp.go b/vendor/github.com/google/nftables/xt/match_udp.go new file mode 100644 index 0000000000..68ce12a069 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/match_udp.go @@ -0,0 +1,57 @@ +package xt + +import ( + "github.com/google/nftables/alignedbuff" +) + +// Tcp is 
the Match.Info payload for the tcp xtables extension +// (https://wiki.nftables.org/wiki-nftables/index.php/Supported_features_compared_to_xtables#tcp). +// +// See +// https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_tcpudp.h#L25 +type Udp struct { + SrcPorts [2]uint16 // min, max source port range + DstPorts [2]uint16 // min, max destination port range + InvFlags UdpInvFlagset // Inverse flags +} + +type UdpInvFlagset uint8 + +const ( + UdpInvSrcPorts UdpInvFlagset = 1 << iota + UdpInvDestPorts + UdpInvMask UdpInvFlagset = (1 << iota) - 1 +) + +func (x *Udp) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + ab.PutUint16(x.SrcPorts[0]) + ab.PutUint16(x.SrcPorts[1]) + ab.PutUint16(x.DstPorts[0]) + ab.PutUint16(x.DstPorts[1]) + ab.PutUint8(byte(x.InvFlags)) + return ab.Data(), nil +} + +func (x *Udp) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if x.SrcPorts[0], err = ab.Uint16(); err != nil { + return err + } + if x.SrcPorts[1], err = ab.Uint16(); err != nil { + return err + } + if x.DstPorts[0], err = ab.Uint16(); err != nil { + return err + } + if x.DstPorts[1], err = ab.Uint16(); err != nil { + return err + } + var invFlags uint8 + if invFlags, err = ab.Uint8(); err != nil { + return err + } + x.InvFlags = UdpInvFlagset(invFlags) + return nil +} diff --git a/vendor/github.com/google/nftables/xt/target_dnat.go b/vendor/github.com/google/nftables/xt/target_dnat.go new file mode 100644 index 0000000000..b54e8fbefb --- /dev/null +++ b/vendor/github.com/google/nftables/xt/target_dnat.go @@ -0,0 +1,106 @@ +package xt + +import ( + "net" + + "github.com/google/nftables/alignedbuff" +) + +type NatRangeFlags uint + +// See: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L8 +const ( + NatRangeMapIPs NatRangeFlags = (1 << iota) + NatRangeProtoSpecified + NatRangeProtoRandom + NatRangePersistent + 
NatRangeProtoRandomFully + NatRangeProtoOffset + NatRangeNetmap + + NatRangeMask NatRangeFlags = (1 << iota) - 1 + + NatRangeProtoRandomAll = NatRangeProtoRandom | NatRangeProtoRandomFully +) + +// see: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L38 +type NatRange struct { + Flags uint // sic! platform/arch/compiler-dependent uint size + MinIP net.IP // always taking up space for an IPv6 address + MaxIP net.IP // dito + MinPort uint16 + MaxPort uint16 +} + +// see: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L46 +type NatRange2 struct { + NatRange + BasePort uint16 +} + +func (x *NatRange) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + return ab.Data(), nil +} + +func (x *NatRange) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + ab.PutUint(x.Flags) + if err := putIPv46(ab, fam, x.MinIP); err != nil { + return err + } + if err := putIPv46(ab, fam, x.MaxIP); err != nil { + return err + } + ab.PutUint16BE(x.MinPort) + ab.PutUint16BE(x.MaxPort) + return nil +} + +func (x *NatRange) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + return x.unmarshalAB(fam, rev, &ab) +} + +func (x *NatRange) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + var err error + if x.Flags, err = ab.Uint(); err != nil { + return err + } + if x.MinIP, err = iPv46(ab, fam); err != nil { + return err + } + if x.MaxIP, err = iPv46(ab, fam); err != nil { + return err + } + if x.MinPort, err = ab.Uint16BE(); err != nil { + return err + } + if x.MaxPort, err = ab.Uint16BE(); err != nil { + return err + } + return nil +} + +func (x *NatRange2) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.NatRange.marshalAB(fam, rev, &ab); err != nil { + return nil, 
err + } + ab.PutUint16BE(x.BasePort) + return ab.Data(), nil +} + +func (x *NatRange2) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if err = x.NatRange.unmarshalAB(fam, rev, &ab); err != nil { + return err + } + if x.BasePort, err = ab.Uint16BE(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/google/nftables/xt/target_masquerade_ip.go b/vendor/github.com/google/nftables/xt/target_masquerade_ip.go new file mode 100644 index 0000000000..411d3beaa1 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/target_masquerade_ip.go @@ -0,0 +1,86 @@ +package xt + +import ( + "errors" + "net" + + "github.com/google/nftables/alignedbuff" +) + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L25 +type NatIPv4Range struct { + Flags uint // sic! + MinIP net.IP + MaxIP net.IP + MinPort uint16 + MaxPort uint16 +} + +// NatIPv4MultiRangeCompat despite being a slice of NAT IPv4 ranges is currently allowed to +// only hold exactly one element. 
+// +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L33 +type NatIPv4MultiRangeCompat []NatIPv4Range + +func (x *NatIPv4MultiRangeCompat) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if len(*x) != 1 { + return nil, errors.New("MasqueradeIp must contain exactly one NatIPv4Range") + } + ab.PutUint(uint(len(*x))) + for _, nat := range *x { + if err := nat.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + } + return ab.Data(), nil +} + +func (x *NatIPv4MultiRangeCompat) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + l, err := ab.Uint() + if err != nil { + return err + } + nats := make(NatIPv4MultiRangeCompat, l) + for l > 0 { + l-- + if err := nats[l].unmarshalAB(fam, rev, &ab); err != nil { + return err + } + } + *x = nats + return nil +} + +func (x *NatIPv4Range) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + ab.PutUint(x.Flags) + ab.PutBytesAligned32(x.MinIP.To4(), 4) + ab.PutBytesAligned32(x.MaxIP.To4(), 4) + ab.PutUint16BE(x.MinPort) + ab.PutUint16BE(x.MaxPort) + return nil +} + +func (x *NatIPv4Range) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + var err error + if x.Flags, err = ab.Uint(); err != nil { + return err + } + var ip []byte + if ip, err = ab.BytesAligned32(4); err != nil { + return err + } + x.MinIP = net.IP(ip) + if ip, err = ab.BytesAligned32(4); err != nil { + return err + } + x.MaxIP = net.IP(ip) + if x.MinPort, err = ab.Uint16BE(); err != nil { + return err + } + if x.MaxPort, err = ab.Uint16BE(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/google/nftables/xt/unknown.go b/vendor/github.com/google/nftables/xt/unknown.go new file mode 100644 index 0000000000..c648307c5c --- /dev/null +++ b/vendor/github.com/google/nftables/xt/unknown.go @@ -0,0 +1,17 @@ +package xt + +// Unknown represents the bytes Info 
payload for unknown Info types where no +// dedicated match/target info type has (yet) been defined. +type Unknown []byte + +func (x *Unknown) marshal(fam TableFamily, rev uint32) ([]byte, error) { + // In case of unknown payload we assume its creator knows what she/he does + // and thus we don't do any alignment padding. Just take the payload "as + // is". + return *x, nil +} + +func (x *Unknown) unmarshal(fam TableFamily, rev uint32, data []byte) error { + *x = data + return nil +} diff --git a/vendor/github.com/google/nftables/xt/util.go b/vendor/github.com/google/nftables/xt/util.go new file mode 100644 index 0000000000..673ac54f76 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/util.go @@ -0,0 +1,64 @@ +package xt + +import ( + "fmt" + "net" + + "github.com/google/nftables/alignedbuff" + "golang.org/x/sys/unix" +) + +func bool32(ab *alignedbuff.AlignedBuff) (bool, error) { + v, err := ab.Uint32() + if err != nil { + return false, err + } + if v != 0 { + return true, nil + } + return false, nil +} + +func putBool32(ab *alignedbuff.AlignedBuff, b bool) { + if b { + ab.PutUint32(1) + return + } + ab.PutUint32(0) +} + +func iPv46(ab *alignedbuff.AlignedBuff, fam TableFamily) (net.IP, error) { + ip, err := ab.BytesAligned32(16) + if err != nil { + return nil, err + } + switch fam { + case unix.NFPROTO_IPV4: + return net.IP(ip[:4]), nil + case unix.NFPROTO_IPV6: + return net.IP(ip), nil + default: + return nil, fmt.Errorf("unmarshal IP: unsupported table family %d", fam) + } +} + +func iPv46Mask(ab *alignedbuff.AlignedBuff, fam TableFamily) (net.IPMask, error) { + v, err := iPv46(ab, fam) + return net.IPMask(v), err +} + +func putIPv46(ab *alignedbuff.AlignedBuff, fam TableFamily, ip net.IP) error { + switch fam { + case unix.NFPROTO_IPV4: + ab.PutBytesAligned32(ip.To4(), 16) + case unix.NFPROTO_IPV6: + ab.PutBytesAligned32(ip.To16(), 16) + default: + return fmt.Errorf("marshal IP: unsupported table family %d", fam) + } + return nil +} + +func 
putIPv46Mask(ab *alignedbuff.AlignedBuff, fam TableFamily, mask net.IPMask) error { + return putIPv46(ab, fam, net.IP(mask)) +} diff --git a/vendor/github.com/google/nftables/xt/xt.go b/vendor/github.com/google/nftables/xt/xt.go new file mode 100644 index 0000000000..d8977c1d04 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/xt.go @@ -0,0 +1,48 @@ +/* +Package xt implements dedicated types for (some) of the "Info" payload in Match +and Target expressions that bridge between the nftables and xtables worlds. + +Bridging between the more unified world of nftables and the slightly +heterogenous world of xtables comes with some caveats. Unmarshalling the +extension/translation information in Match and Target expressions requires +information about the table family the information belongs to, as well as type +and type revision information. In consequence, unmarshalling the Match and +Target Info field payloads often (but not necessarily always) require the table +family and revision information, so it gets passed to the type-specific +unmarshallers. + +To complicate things more, even marshalling requires knowledge about the +enclosing table family. The NatRange/NatRange2 types are an example, where it is +necessary to differentiate between IPv4 and IPv6 address marshalling. Due to +Go's net.IP habit to normally store IPv4 addresses as IPv4-compatible IPv6 +addresses (see also RFC 4291, section 2.5.5.1) marshalling must be handled +differently in the context of an IPv6 table compared to an IPv4 table. In an +IPv4 table, an IPv4-compatible IPv6 address must be marshalled as a 32bit +address, whereas in an IPv6 table the IPv4 address must be marshalled as an +128bit IPv4-compatible IPv6 address. Not relying on heuristics here we avoid +behavior unexpected and most probably unknown to our API users. 
The net.IP habit +of storing IPv4 addresses in two different storage formats is already a source +for trouble, especially when comparing net.IPs from different Go module sources. +We won't add to this confusion. (...or maybe we can, because of it?) + +An important property of all types of Info extension/translation payloads is +that their marshalling and unmarshalling doesn't follow netlink's TLV +(tag-length-value) architecture. Instead, Info payloads a basically plain binary +blobs of their respective type-specific data structures, so host +platform/architecture alignment and data type sizes apply. The alignedbuff +package implements the different required data types alignments. + +Please note that Info payloads are always padded at their end to the next uint64 +alignment. Kernel code is checking for the padded payload size and will reject +payloads not correctly padded at their ends. + +Most of the time, we find explifcitly sized (unsigned integer) data types. +However, there are notable exceptions where "unsigned int" is used: on 64bit +platforms this mostly translates into 32bit(!). This differs from Go mapping +uint to uint64 instead. This package currently clamps its mapping of C's +"unsigned int" to Go's uint32 for marshalling and unmarshalling. If in the +future 128bit platforms with a differently sized C unsigned int should come into +production, then the alignedbuff package will need to be adapted accordingly, as +it abstracts away this data type handling. 
+*/ +package xt diff --git a/vendor/github.com/mdlayher/netlink/nltest/errors_others.go b/vendor/github.com/mdlayher/netlink/nltest/errors_others.go new file mode 100644 index 0000000000..3a29c9b1ac --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nltest/errors_others.go @@ -0,0 +1,8 @@ +//go:build plan9 || windows +// +build plan9 windows + +package nltest + +func isSyscallError(_ error) bool { + return false +} diff --git a/vendor/github.com/mdlayher/netlink/nltest/errors_unix.go b/vendor/github.com/mdlayher/netlink/nltest/errors_unix.go new file mode 100644 index 0000000000..f54403bb08 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nltest/errors_unix.go @@ -0,0 +1,11 @@ +//go:build !plan9 && !windows +// +build !plan9,!windows + +package nltest + +import "golang.org/x/sys/unix" + +func isSyscallError(err error) bool { + _, ok := err.(unix.Errno) + return ok +} diff --git a/vendor/github.com/mdlayher/netlink/nltest/nltest.go b/vendor/github.com/mdlayher/netlink/nltest/nltest.go new file mode 100644 index 0000000000..2065bab02a --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nltest/nltest.go @@ -0,0 +1,207 @@ +// Package nltest provides utilities for netlink testing. +package nltest + +import ( + "fmt" + "io" + "os" + + "github.com/mdlayher/netlink" + "github.com/mdlayher/netlink/nlenc" +) + +// PID is the netlink header PID value assigned by nltest. +const PID = 1 + +// MustMarshalAttributes marshals a slice of netlink.Attributes to their binary +// format, but panics if any errors occur. +func MustMarshalAttributes(attrs []netlink.Attribute) []byte { + b, err := netlink.MarshalAttributes(attrs) + if err != nil { + panic(fmt.Sprintf("failed to marshal attributes to binary: %v", err)) + } + + return b +} + +// Multipart sends a slice of netlink.Messages to the caller as a +// netlink multi-part message. If less than two messages are present, +// the messages are not altered. 
+func Multipart(msgs []netlink.Message) ([]netlink.Message, error) { + if len(msgs) < 2 { + return msgs, nil + } + + for i := range msgs { + // Last message has header type "done" in addition to multi-part flag. + if i == len(msgs)-1 { + msgs[i].Header.Type = netlink.Done + } + + msgs[i].Header.Flags |= netlink.Multi + } + + return msgs, nil +} + +// Error returns a netlink error to the caller with the specified error +// number, in the body of the specified request message. +func Error(number int, reqs []netlink.Message) ([]netlink.Message, error) { + req := reqs[0] + req.Header.Length += 4 + req.Header.Type = netlink.Error + + errno := -1 * int32(number) + req.Data = append(nlenc.Int32Bytes(errno), req.Data...) + + return []netlink.Message{req}, nil +} + +// A Func is a function that can be used to test netlink.Conn interactions. +// The function can choose to return zero or more netlink messages, or an +// error if needed. +// +// For a netlink request/response interaction, a request req is populated by +// netlink.Conn.Send and passed to the function. +// +// For multicast interactions, an empty request req is passed to the function +// when netlink.Conn.Receive is called. +// +// If a Func returns an error, the error will be returned as-is to the caller. +// If no messages and io.EOF are returned, no messages and no error will be +// returned to the caller, simulating a multi-part message with no data. +type Func func(req []netlink.Message) ([]netlink.Message, error) + +// Dial sets up a netlink.Conn for testing using the specified Func. All requests +// sent from the connection will be passed to the Func. The connection should be +// closed as usual when it is no longer needed. 
+func Dial(fn Func) *netlink.Conn { + sock := &socket{ + fn: fn, + } + + return netlink.NewConn(sock, PID) +} + +// CheckRequest returns a Func that verifies that each message in an incoming +// request has the specified netlink header type and flags in the same slice +// position index, and then passes the request through to fn. +// +// The length of the types and flags slices must match the number of requests +// passed to the returned Func, or CheckRequest will panic. +// +// As an example: +// - types[0] and flags[0] will be checked against reqs[0] +// - types[1] and flags[1] will be checked against reqs[1] +// - ... and so on +// +// If an element of types or flags is set to the zero value, that check will +// be skipped for the request message that occurs at the same index. +// +// As an example, if types[0] is 0 and reqs[0].Header.Type is 1, the check will +// succeed because types[0] was not specified. +func CheckRequest(types []netlink.HeaderType, flags []netlink.HeaderFlags, fn Func) Func { + if len(types) != len(flags) { + panicf("nltest: CheckRequest called with mismatched types and flags slice lengths: %d != %d", + len(types), len(flags)) + } + + return func(req []netlink.Message) ([]netlink.Message, error) { + if len(types) != len(req) { + panicf("nltest: CheckRequest function invoked types/flags and request message slice lengths: %d != %d", + len(types), len(req)) + } + + for i := range req { + if want, got := types[i], req[i].Header.Type; types[i] != 0 && want != got { + return nil, fmt.Errorf("nltest: unexpected netlink header type: %s, want: %s", got, want) + } + + if want, got := flags[i], req[i].Header.Flags; flags[i] != 0 && want != got { + return nil, fmt.Errorf("nltest: unexpected netlink header flags: %s, want: %s", got, want) + } + } + + return fn(req) + } +} + +// A socket is a netlink.Socket used for testing. 
+type socket struct { + fn Func + + msgs []netlink.Message + err error +} + +func (c *socket) Close() error { return nil } + +func (c *socket) SendMessages(messages []netlink.Message) error { + msgs, err := c.fn(messages) + c.msgs = append(c.msgs, msgs...) + c.err = err + return nil +} + +func (c *socket) Send(m netlink.Message) error { + c.msgs, c.err = c.fn([]netlink.Message{m}) + return nil +} + +func (c *socket) Receive() ([]netlink.Message, error) { + // No messages set by Send means that we are emulating a + // multicast response or an error occurred. + if len(c.msgs) == 0 { + switch c.err { + case nil: + // No error, simulate multicast, but also return EOF to simulate + // no replies if needed. + msgs, err := c.fn(nil) + if err == io.EOF { + err = nil + } + + return msgs, err + case io.EOF: + // EOF, simulate no replies in multi-part message. + return nil, nil + } + + // If the error is a system call error, wrap it in os.NewSyscallError + // to simulate what the Linux netlink.Conn does. + if isSyscallError(c.err) { + return nil, os.NewSyscallError("recvmsg", c.err) + } + + // Some generic error occurred and should be passed to the caller. + return nil, c.err + } + + // Detect multi-part messages. + var multi bool + for _, m := range c.msgs { + if m.Header.Flags&netlink.Multi != 0 && m.Header.Type != netlink.Done { + multi = true + } + } + + // When a multi-part message is detected, return all messages except for the + // final "multi-part done", so that a second call to Receive from netlink.Conn + // will drain that message. 
+ if multi { + last := c.msgs[len(c.msgs)-1] + ret := c.msgs[:len(c.msgs)-1] + c.msgs = []netlink.Message{last} + + return ret, c.err + } + + msgs, err := c.msgs, c.err + c.msgs, c.err = nil, nil + + return msgs, err +} + +func panicf(format string, a ...interface{}) { + panic(fmt.Sprintf(format, a...)) +} diff --git a/vendor/github.com/stretchr/testify/assert/assertion_compare.go b/vendor/github.com/stretchr/testify/assert/assertion_compare.go index 95d8e59da6..b774da88d8 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_compare.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_compare.go @@ -352,9 +352,9 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { // Greater asserts that the first element is greater than the second // -// assert.Greater(t, 2, 1) -// assert.Greater(t, float64(2), float64(1)) -// assert.Greater(t, "b", "a") +// assert.Greater(t, 2, 1) +// assert.Greater(t, float64(2), float64(1)) +// assert.Greater(t, "b", "a") func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -364,10 +364,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface // GreaterOrEqual asserts that the first element is greater than or equal to the second // -// assert.GreaterOrEqual(t, 2, 1) -// assert.GreaterOrEqual(t, 2, 2) -// assert.GreaterOrEqual(t, "b", "a") -// assert.GreaterOrEqual(t, "b", "b") +// assert.GreaterOrEqual(t, 2, 1) +// assert.GreaterOrEqual(t, 2, 2) +// assert.GreaterOrEqual(t, "b", "a") +// assert.GreaterOrEqual(t, "b", "b") func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -377,9 +377,9 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in // Less asserts that the first element is less than the second // -// assert.Less(t, 1, 2) -// assert.Less(t, float64(1), 
float64(2)) -// assert.Less(t, "a", "b") +// assert.Less(t, 1, 2) +// assert.Less(t, float64(1), float64(2)) +// assert.Less(t, "a", "b") func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -389,10 +389,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) // LessOrEqual asserts that the first element is less than or equal to the second // -// assert.LessOrEqual(t, 1, 2) -// assert.LessOrEqual(t, 2, 2) -// assert.LessOrEqual(t, "a", "b") -// assert.LessOrEqual(t, "b", "b") +// assert.LessOrEqual(t, 1, 2) +// assert.LessOrEqual(t, 2, 2) +// assert.LessOrEqual(t, "a", "b") +// assert.LessOrEqual(t, "b", "b") func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -402,8 +402,8 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter // Positive asserts that the specified element is positive // -// assert.Positive(t, 1) -// assert.Positive(t, 1.23) +// assert.Positive(t, 1) +// assert.Positive(t, 1.23) func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -414,8 +414,8 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { // Negative asserts that the specified element is negative // -// assert.Negative(t, -1) -// assert.Negative(t, -1.23) +// assert.Negative(t, -1) +// assert.Negative(t, -1.23) func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() diff --git a/vendor/github.com/stretchr/testify/assert/assertion_format.go b/vendor/github.com/stretchr/testify/assert/assertion_format.go index 7880b8f943..84dbd6c790 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_format.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_format.go @@ -22,9 +22,9 @@ func Conditionf(t TestingT, 
comp Comparison, msg string, args ...interface{}) bo // Containsf asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // -// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted") -// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") -// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") +// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted") +// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") +// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -56,7 +56,7 @@ func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string // Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // -// assert.Emptyf(t, obj, "error message %s", "formatted") +// assert.Emptyf(t, obj, "error message %s", "formatted") func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -66,7 +66,7 @@ func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) boo // Equalf asserts that two objects are equal. // -// assert.Equalf(t, 123, 123, "error message %s", "formatted") +// assert.Equalf(t, 123, 123, "error message %s", "formatted") // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality @@ -81,8 +81,8 @@ func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, ar // EqualErrorf asserts that a function returned an error (i.e. not `nil`) // and that it is equal to the provided error. 
// -// actualObj, err := SomeFunction() -// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") +// actualObj, err := SomeFunction() +// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -90,10 +90,27 @@ func EqualErrorf(t TestingT, theError error, errString string, msg string, args return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...) } +// EqualExportedValuesf asserts that the types of two objects are equal and their public +// fields are also equal. This is useful for comparing structs that have private fields +// that could potentially differ. +// +// type S struct { +// Exported int +// notExported int +// } +// assert.EqualExportedValuesf(t, S{1, 2}, S{1, 3}, "error message %s", "formatted") => true +// assert.EqualExportedValuesf(t, S{1, 2}, S{2, 3}, "error message %s", "formatted") => false +func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...) +} + // EqualValuesf asserts that two objects are equal or convertable to the same types // and equal. // -// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") +// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -103,10 +120,10 @@ func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg stri // Errorf asserts that a function returned an error (i.e. not `nil`). 
// -// actualObj, err := SomeFunction() -// if assert.Errorf(t, err, "error message %s", "formatted") { -// assert.Equal(t, expectedErrorf, err) -// } +// actualObj, err := SomeFunction() +// if assert.Errorf(t, err, "error message %s", "formatted") { +// assert.Equal(t, expectedErrorf, err) +// } func Errorf(t TestingT, err error, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -126,8 +143,8 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int // ErrorContainsf asserts that a function returned an error (i.e. not `nil`) // and that the error contains the specified substring. // -// actualObj, err := SomeFunction() -// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") +// actualObj, err := SomeFunction() +// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -147,7 +164,7 @@ func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface // Eventuallyf asserts that given condition will be met in waitFor time, // periodically checking target function each tick. // -// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -155,9 +172,34 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick return Eventually(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...) 
} +// EventuallyWithTf asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. In contrast to Eventually, +// it supplies a CollectT to the condition function, so that the condition +// function can use the CollectT to call other assertions. +// The condition is considered "met" if no errors are raised in a tick. +// The supplied CollectT collects all errors from one tick (if there are any). +// If the condition is not met before waitFor, the collected errors of +// the last tick are copied to t. +// +// externalValue := false +// go func() { +// time.Sleep(8*time.Second) +// externalValue = true +// }() +// assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") { +// // add assertions as needed; any assertion failure will fail the current tick +// assert.True(c, externalValue, "expected 'externalValue' to be true") +// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return EventuallyWithT(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...) +} + // Exactlyf asserts that two objects are equal in value and type. // -// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted") +// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted") func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -183,7 +225,7 @@ func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{} // Falsef asserts that the specified value is false. 
// -// assert.Falsef(t, myBool, "error message %s", "formatted") +// assert.Falsef(t, myBool, "error message %s", "formatted") func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -202,9 +244,9 @@ func FileExistsf(t TestingT, path string, msg string, args ...interface{}) bool // Greaterf asserts that the first element is greater than the second // -// assert.Greaterf(t, 2, 1, "error message %s", "formatted") -// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted") -// assert.Greaterf(t, "b", "a", "error message %s", "formatted") +// assert.Greaterf(t, 2, 1, "error message %s", "formatted") +// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted") +// assert.Greaterf(t, "b", "a", "error message %s", "formatted") func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -214,10 +256,10 @@ func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...in // GreaterOrEqualf asserts that the first element is greater than or equal to the second // -// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") -// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") -// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") -// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") +// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") +// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") +// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") +// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -228,7 +270,7 @@ func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, arg // 
HTTPBodyContainsf asserts that a specified handler returns a // body that contains a string. // -// assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { @@ -241,7 +283,7 @@ func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url // HTTPBodyNotContainsf asserts that a specified handler returns a // body that does not contain a string. // -// assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { @@ -253,7 +295,7 @@ func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, u // HTTPErrorf asserts that a specified handler returns an error status code. // -// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). 
func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { @@ -265,7 +307,7 @@ func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, // HTTPRedirectf asserts that a specified handler returns a redirect status code. // -// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { @@ -277,7 +319,7 @@ func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url stri // HTTPStatusCodef asserts that a specified handler returns a specified status code. // -// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") +// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) bool { @@ -289,7 +331,7 @@ func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url st // HTTPSuccessf asserts that a specified handler returns a success status code. // -// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") +// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). 
func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { @@ -301,7 +343,7 @@ func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url strin // Implementsf asserts that an object is implemented by the specified interface. // -// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -311,7 +353,7 @@ func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, ms // InDeltaf asserts that the two numerals are within delta of each other. // -// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") +// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -353,9 +395,9 @@ func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsil // IsDecreasingf asserts that the collection is decreasing // -// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted") -// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted") -// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted") +// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted") +// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted") +// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted") func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -365,9 +407,9 @@ func IsDecreasingf(t TestingT, object 
interface{}, msg string, args ...interface // IsIncreasingf asserts that the collection is increasing // -// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted") -// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted") -// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted") +// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted") +// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted") +// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted") func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -377,9 +419,9 @@ func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface // IsNonDecreasingf asserts that the collection is not decreasing // -// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted") -// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted") -// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted") +// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted") +// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted") +// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted") func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -389,9 +431,9 @@ func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interf // IsNonIncreasingf asserts that the collection is not increasing // -// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted") -// assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted") -// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted") +// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted") +// 
assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted") +// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted") func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -409,7 +451,7 @@ func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg strin // JSONEqf asserts that two JSON strings are equivalent. // -// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") +// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -420,7 +462,7 @@ func JSONEqf(t TestingT, expected string, actual string, msg string, args ...int // Lenf asserts that the specified object has specific length. // Lenf also fails if the object has a type that len() not accept. 
// -// assert.Lenf(t, mySlice, 3, "error message %s", "formatted") +// assert.Lenf(t, mySlice, 3, "error message %s", "formatted") func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -430,9 +472,9 @@ func Lenf(t TestingT, object interface{}, length int, msg string, args ...interf // Lessf asserts that the first element is less than the second // -// assert.Lessf(t, 1, 2, "error message %s", "formatted") -// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted") -// assert.Lessf(t, "a", "b", "error message %s", "formatted") +// assert.Lessf(t, 1, 2, "error message %s", "formatted") +// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted") +// assert.Lessf(t, "a", "b", "error message %s", "formatted") func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -442,10 +484,10 @@ func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...inter // LessOrEqualf asserts that the first element is less than or equal to the second // -// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted") -// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted") -// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted") -// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted") +// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted") +// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted") +// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted") +// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted") func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -455,8 +497,8 @@ func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args . 
// Negativef asserts that the specified element is negative // -// assert.Negativef(t, -1, "error message %s", "formatted") -// assert.Negativef(t, -1.23, "error message %s", "formatted") +// assert.Negativef(t, -1, "error message %s", "formatted") +// assert.Negativef(t, -1.23, "error message %s", "formatted") func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -467,7 +509,7 @@ func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool // Neverf asserts that the given condition doesn't satisfy in waitFor time, // periodically checking the target function each tick. // -// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -477,7 +519,7 @@ func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time. // Nilf asserts that the specified object is nil. // -// assert.Nilf(t, err, "error message %s", "formatted") +// assert.Nilf(t, err, "error message %s", "formatted") func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -496,10 +538,10 @@ func NoDirExistsf(t TestingT, path string, msg string, args ...interface{}) bool // NoErrorf asserts that a function returned no error (i.e. `nil`). 
// -// actualObj, err := SomeFunction() -// if assert.NoErrorf(t, err, "error message %s", "formatted") { -// assert.Equal(t, expectedObj, actualObj) -// } +// actualObj, err := SomeFunction() +// if assert.NoErrorf(t, err, "error message %s", "formatted") { +// assert.Equal(t, expectedObj, actualObj) +// } func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -519,9 +561,9 @@ func NoFileExistsf(t TestingT, path string, msg string, args ...interface{}) boo // NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. // -// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") -// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") -// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -532,9 +574,9 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a // NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. 
// -// if assert.NotEmptyf(t, obj, "error message %s", "formatted") { -// assert.Equal(t, "two", obj[1]) -// } +// if assert.NotEmptyf(t, obj, "error message %s", "formatted") { +// assert.Equal(t, "two", obj[1]) +// } func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -544,7 +586,7 @@ func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) // NotEqualf asserts that the specified values are NOT equal. // -// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted") +// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted") // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). @@ -557,7 +599,7 @@ func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, // NotEqualValuesf asserts that two objects are not equal even when converted to the same type // -// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") +// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -576,7 +618,7 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf // NotNilf asserts that the specified object is not nil. // -// assert.NotNilf(t, err, "error message %s", "formatted") +// assert.NotNilf(t, err, "error message %s", "formatted") func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -586,7 +628,7 @@ func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bo // NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic. 
// -// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") +// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -596,8 +638,8 @@ func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bo // NotRegexpf asserts that a specified regexp does not match a string. // -// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") -// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") +// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") +// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -607,7 +649,7 @@ func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args .. // NotSamef asserts that two pointers do not reference the same object. // -// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted") +// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted") // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -621,7 +663,7 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string, // NotSubsetf asserts that the specified list(array, slice...) contains not all // elements given in the specified subset(array, slice...). 
// -// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") +// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -639,7 +681,7 @@ func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) bool { // Panicsf asserts that the code inside the specified PanicTestFunc panics. // -// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") +// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -651,7 +693,7 @@ func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool // panics, and that the recovered panic value is an error that satisfies the // EqualError comparison. // -// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") func PanicsWithErrorf(t TestingT, errString string, f PanicTestFunc, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -662,7 +704,7 @@ func PanicsWithErrorf(t TestingT, errString string, f PanicTestFunc, msg string, // PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. 
// -// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -672,8 +714,8 @@ func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg str // Positivef asserts that the specified element is positive // -// assert.Positivef(t, 1, "error message %s", "formatted") -// assert.Positivef(t, 1.23, "error message %s", "formatted") +// assert.Positivef(t, 1, "error message %s", "formatted") +// assert.Positivef(t, 1.23, "error message %s", "formatted") func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -683,8 +725,8 @@ func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool // Regexpf asserts that a specified regexp matches a string. // -// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") -// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") +// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") +// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -694,7 +736,7 @@ func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...in // Samef asserts that two pointers reference the same object. // -// assert.Samef(t, ptr1, ptr2, "error message %s", "formatted") +// assert.Samef(t, ptr1, ptr2, "error message %s", "formatted") // // Both arguments must be pointer variables. 
Pointer variable sameness is // determined based on the equality of both type and value. @@ -708,7 +750,7 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg // Subsetf asserts that the specified list(array, slice...) contains all // elements given in the specified subset(array, slice...). // -// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") +// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -718,7 +760,7 @@ func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args // Truef asserts that the specified value is true. // -// assert.Truef(t, myBool, "error message %s", "formatted") +// assert.Truef(t, myBool, "error message %s", "formatted") func Truef(t TestingT, value bool, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -728,7 +770,7 @@ func Truef(t TestingT, value bool, msg string, args ...interface{}) bool { // WithinDurationf asserts that the two times are within duration delta of each other. // -// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") +// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -738,7 +780,7 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim // WithinRangef asserts that a time is within a time range (inclusive). 
// -// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") +// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() diff --git a/vendor/github.com/stretchr/testify/assert/assertion_forward.go b/vendor/github.com/stretchr/testify/assert/assertion_forward.go index 339515b8bf..b1d94aec53 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_forward.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_forward.go @@ -30,9 +30,9 @@ func (a *Assertions) Conditionf(comp Comparison, msg string, args ...interface{} // Contains asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // -// a.Contains("Hello World", "World") -// a.Contains(["Hello", "World"], "World") -// a.Contains({"Hello": "World"}, "Hello") +// a.Contains("Hello World", "World") +// a.Contains(["Hello", "World"], "World") +// a.Contains({"Hello": "World"}, "Hello") func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -43,9 +43,9 @@ func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs .. // Containsf asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. 
// -// a.Containsf("Hello World", "World", "error message %s", "formatted") -// a.Containsf(["Hello", "World"], "World", "error message %s", "formatted") -// a.Containsf({"Hello": "World"}, "Hello", "error message %s", "formatted") +// a.Containsf("Hello World", "World", "error message %s", "formatted") +// a.Containsf(["Hello", "World"], "World", "error message %s", "formatted") +// a.Containsf({"Hello": "World"}, "Hello", "error message %s", "formatted") func (a *Assertions) Containsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -98,7 +98,7 @@ func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg st // Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // -// a.Empty(obj) +// a.Empty(obj) func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -109,7 +109,7 @@ func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { // Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // -// a.Emptyf(obj, "error message %s", "formatted") +// a.Emptyf(obj, "error message %s", "formatted") func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -119,7 +119,7 @@ func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) // Equal asserts that two objects are equal. // -// a.Equal(123, 123) +// a.Equal(123, 123) // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality @@ -134,8 +134,8 @@ func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs // EqualError asserts that a function returned an error (i.e. 
not `nil`) // and that it is equal to the provided error. // -// actualObj, err := SomeFunction() -// a.EqualError(err, expectedErrorString) +// actualObj, err := SomeFunction() +// a.EqualError(err, expectedErrorString) func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -146,8 +146,8 @@ func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ... // EqualErrorf asserts that a function returned an error (i.e. not `nil`) // and that it is equal to the provided error. // -// actualObj, err := SomeFunction() -// a.EqualErrorf(err, expectedErrorString, "error message %s", "formatted") +// actualObj, err := SomeFunction() +// a.EqualErrorf(err, expectedErrorString, "error message %s", "formatted") func (a *Assertions) EqualErrorf(theError error, errString string, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -155,10 +155,44 @@ func (a *Assertions) EqualErrorf(theError error, errString string, msg string, a return EqualErrorf(a.t, theError, errString, msg, args...) } +// EqualExportedValues asserts that the types of two objects are equal and their public +// fields are also equal. This is useful for comparing structs that have private fields +// that could potentially differ. +// +// type S struct { +// Exported int +// notExported int +// } +// a.EqualExportedValues(S{1, 2}, S{1, 3}) => true +// a.EqualExportedValues(S{1, 2}, S{2, 3}) => false +func (a *Assertions) EqualExportedValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EqualExportedValues(a.t, expected, actual, msgAndArgs...) +} + +// EqualExportedValuesf asserts that the types of two objects are equal and their public +// fields are also equal. This is useful for comparing structs that have private fields +// that could potentially differ. 
+// +// type S struct { +// Exported int +// notExported int +// } +// a.EqualExportedValuesf(S{1, 2}, S{1, 3}, "error message %s", "formatted") => true +// a.EqualExportedValuesf(S{1, 2}, S{2, 3}, "error message %s", "formatted") => false +func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EqualExportedValuesf(a.t, expected, actual, msg, args...) +} + // EqualValues asserts that two objects are equal or convertable to the same types // and equal. // -// a.EqualValues(uint32(123), int32(123)) +// a.EqualValues(uint32(123), int32(123)) func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -169,7 +203,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn // EqualValuesf asserts that two objects are equal or convertable to the same types // and equal. // -// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") +// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -179,7 +213,7 @@ func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg // Equalf asserts that two objects are equal. // -// a.Equalf(123, 123, "error message %s", "formatted") +// a.Equalf(123, 123, "error message %s", "formatted") // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality @@ -193,10 +227,10 @@ func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string // Error asserts that a function returned an error (i.e. not `nil`). 
// -// actualObj, err := SomeFunction() -// if a.Error(err) { -// assert.Equal(t, expectedError, err) -// } +// actualObj, err := SomeFunction() +// if a.Error(err) { +// assert.Equal(t, expectedError, err) +// } func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -225,8 +259,8 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args .. // ErrorContains asserts that a function returned an error (i.e. not `nil`) // and that the error contains the specified substring. // -// actualObj, err := SomeFunction() -// a.ErrorContains(err, expectedErrorSubString) +// actualObj, err := SomeFunction() +// a.ErrorContains(err, expectedErrorSubString) func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -237,8 +271,8 @@ func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs . // ErrorContainsf asserts that a function returned an error (i.e. not `nil`) // and that the error contains the specified substring. // -// actualObj, err := SomeFunction() -// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted") +// actualObj, err := SomeFunction() +// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted") func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -266,10 +300,10 @@ func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...inter // Errorf asserts that a function returned an error (i.e. not `nil`). 
// -// actualObj, err := SomeFunction() -// if a.Errorf(err, "error message %s", "formatted") { -// assert.Equal(t, expectedErrorf, err) -// } +// actualObj, err := SomeFunction() +// if a.Errorf(err, "error message %s", "formatted") { +// assert.Equal(t, expectedErrorf, err) +// } func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -280,7 +314,7 @@ func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool { // Eventually asserts that given condition will be met in waitFor time, // periodically checking target function each tick. // -// a.Eventually(func() bool { return true; }, time.Second, 10*time.Millisecond) +// a.Eventually(func() bool { return true; }, time.Second, 10*time.Millisecond) func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -288,10 +322,60 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti return Eventually(a.t, condition, waitFor, tick, msgAndArgs...) } +// EventuallyWithT asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. In contrast to Eventually, +// it supplies a CollectT to the condition function, so that the condition +// function can use the CollectT to call other assertions. +// The condition is considered "met" if no errors are raised in a tick. +// The supplied CollectT collects all errors from one tick (if there are any). +// If the condition is not met before waitFor, the collected errors of +// the last tick are copied to t. 
+// +// externalValue := false +// go func() { +// time.Sleep(8*time.Second) +// externalValue = true +// }() +// a.EventuallyWithT(func(c *assert.CollectT) { +// // add assertions as needed; any assertion failure will fail the current tick +// assert.True(c, externalValue, "expected 'externalValue' to be true") +// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EventuallyWithT(a.t, condition, waitFor, tick, msgAndArgs...) +} + +// EventuallyWithTf asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. In contrast to Eventually, +// it supplies a CollectT to the condition function, so that the condition +// function can use the CollectT to call other assertions. +// The condition is considered "met" if no errors are raised in a tick. +// The supplied CollectT collects all errors from one tick (if there are any). +// If the condition is not met before waitFor, the collected errors of +// the last tick are copied to t. +// +// externalValue := false +// go func() { +// time.Sleep(8*time.Second) +// externalValue = true +// }() +// a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") { +// // add assertions as needed; any assertion failure will fail the current tick +// assert.True(c, externalValue, "expected 'externalValue' to be true") +// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +func (a *Assertions) EventuallyWithTf(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return EventuallyWithTf(a.t, condition, waitFor, tick, msg, args...) 
+} + // Eventuallyf asserts that given condition will be met in waitFor time, // periodically checking target function each tick. // -// a.Eventuallyf(func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +// a.Eventuallyf(func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") func (a *Assertions) Eventuallyf(condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -301,7 +385,7 @@ func (a *Assertions) Eventuallyf(condition func() bool, waitFor time.Duration, t // Exactly asserts that two objects are equal in value and type. // -// a.Exactly(int32(123), int64(123)) +// a.Exactly(int32(123), int64(123)) func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -311,7 +395,7 @@ func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArg // Exactlyf asserts that two objects are equal in value and type. // -// a.Exactlyf(int32(123), int64(123), "error message %s", "formatted") +// a.Exactlyf(int32(123), int64(123), "error message %s", "formatted") func (a *Assertions) Exactlyf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -353,7 +437,7 @@ func (a *Assertions) Failf(failureMessage string, msg string, args ...interface{ // False asserts that the specified value is false. // -// a.False(myBool) +// a.False(myBool) func (a *Assertions) False(value bool, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -363,7 +447,7 @@ func (a *Assertions) False(value bool, msgAndArgs ...interface{}) bool { // Falsef asserts that the specified value is false. 
// -// a.Falsef(myBool, "error message %s", "formatted") +// a.Falsef(myBool, "error message %s", "formatted") func (a *Assertions) Falsef(value bool, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -391,9 +475,9 @@ func (a *Assertions) FileExistsf(path string, msg string, args ...interface{}) b // Greater asserts that the first element is greater than the second // -// a.Greater(2, 1) -// a.Greater(float64(2), float64(1)) -// a.Greater("b", "a") +// a.Greater(2, 1) +// a.Greater(float64(2), float64(1)) +// a.Greater("b", "a") func (a *Assertions) Greater(e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -403,10 +487,10 @@ func (a *Assertions) Greater(e1 interface{}, e2 interface{}, msgAndArgs ...inter // GreaterOrEqual asserts that the first element is greater than or equal to the second // -// a.GreaterOrEqual(2, 1) -// a.GreaterOrEqual(2, 2) -// a.GreaterOrEqual("b", "a") -// a.GreaterOrEqual("b", "b") +// a.GreaterOrEqual(2, 1) +// a.GreaterOrEqual(2, 2) +// a.GreaterOrEqual("b", "a") +// a.GreaterOrEqual("b", "b") func (a *Assertions) GreaterOrEqual(e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -416,10 +500,10 @@ func (a *Assertions) GreaterOrEqual(e1 interface{}, e2 interface{}, msgAndArgs . 
// GreaterOrEqualf asserts that the first element is greater than or equal to the second // -// a.GreaterOrEqualf(2, 1, "error message %s", "formatted") -// a.GreaterOrEqualf(2, 2, "error message %s", "formatted") -// a.GreaterOrEqualf("b", "a", "error message %s", "formatted") -// a.GreaterOrEqualf("b", "b", "error message %s", "formatted") +// a.GreaterOrEqualf(2, 1, "error message %s", "formatted") +// a.GreaterOrEqualf(2, 2, "error message %s", "formatted") +// a.GreaterOrEqualf("b", "a", "error message %s", "formatted") +// a.GreaterOrEqualf("b", "b", "error message %s", "formatted") func (a *Assertions) GreaterOrEqualf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -429,9 +513,9 @@ func (a *Assertions) GreaterOrEqualf(e1 interface{}, e2 interface{}, msg string, // Greaterf asserts that the first element is greater than the second // -// a.Greaterf(2, 1, "error message %s", "formatted") -// a.Greaterf(float64(2), float64(1), "error message %s", "formatted") -// a.Greaterf("b", "a", "error message %s", "formatted") +// a.Greaterf(2, 1, "error message %s", "formatted") +// a.Greaterf(float64(2), float64(1), "error message %s", "formatted") +// a.Greaterf("b", "a", "error message %s", "formatted") func (a *Assertions) Greaterf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -442,7 +526,7 @@ func (a *Assertions) Greaterf(e1 interface{}, e2 interface{}, msg string, args . // HTTPBodyContains asserts that a specified handler returns a // body that contains a string. // -// a.HTTPBodyContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// a.HTTPBodyContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { @@ -455,7 +539,7 @@ func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, u // HTTPBodyContainsf asserts that a specified handler returns a // body that contains a string. // -// a.HTTPBodyContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// a.HTTPBodyContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPBodyContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { @@ -468,7 +552,7 @@ func (a *Assertions) HTTPBodyContainsf(handler http.HandlerFunc, method string, // HTTPBodyNotContains asserts that a specified handler returns a // body that does not contain a string. // -// a.HTTPBodyNotContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// a.HTTPBodyNotContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { @@ -481,7 +565,7 @@ func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string // HTTPBodyNotContainsf asserts that a specified handler returns a // body that does not contain a string. // -// a.HTTPBodyNotContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// a.HTTPBodyNotContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { @@ -493,7 +577,7 @@ func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method strin // HTTPError asserts that a specified handler returns an error status code. // -// a.HTTPError(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// a.HTTPError(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { @@ -505,7 +589,7 @@ func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url stri // HTTPErrorf asserts that a specified handler returns an error status code. // -// a.HTTPErrorf(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// a.HTTPErrorf(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { @@ -517,7 +601,7 @@ func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url str // HTTPRedirect asserts that a specified handler returns a redirect status code. // -// a.HTTPRedirect(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// a.HTTPRedirect(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { @@ -529,7 +613,7 @@ func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url s // HTTPRedirectf asserts that a specified handler returns a redirect status code. // -// a.HTTPRedirectf(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// a.HTTPRedirectf(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { @@ -541,7 +625,7 @@ func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url // HTTPStatusCode asserts that a specified handler returns a specified status code. // -// a.HTTPStatusCode(myHandler, "GET", "/notImplemented", nil, 501) +// a.HTTPStatusCode(myHandler, "GET", "/notImplemented", nil, 501) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPStatusCode(handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) bool { @@ -553,7 +637,7 @@ func (a *Assertions) HTTPStatusCode(handler http.HandlerFunc, method string, url // HTTPStatusCodef asserts that a specified handler returns a specified status code. // -// a.HTTPStatusCodef(myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") +// a.HTTPStatusCodef(myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) HTTPStatusCodef(handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) bool { @@ -565,7 +649,7 @@ func (a *Assertions) HTTPStatusCodef(handler http.HandlerFunc, method string, ur // HTTPSuccess asserts that a specified handler returns a success status code. // -// a.HTTPSuccess(myHandler, "POST", "http://www.google.com", nil) +// a.HTTPSuccess(myHandler, "POST", "http://www.google.com", nil) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { @@ -577,7 +661,7 @@ func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url st // HTTPSuccessf asserts that a specified handler returns a success status code. // -// a.HTTPSuccessf(myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") +// a.HTTPSuccessf(myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { @@ -589,7 +673,7 @@ func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url s // Implements asserts that an object is implemented by the specified interface. // -// a.Implements((*MyInterface)(nil), new(MyObject)) +// a.Implements((*MyInterface)(nil), new(MyObject)) func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -599,7 +683,7 @@ func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, // Implementsf asserts that an object is implemented by the specified interface. 
// -// a.Implementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +// a.Implementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted") func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -609,7 +693,7 @@ func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{} // InDelta asserts that the two numerals are within delta of each other. // -// a.InDelta(math.Pi, 22/7.0, 0.01) +// a.InDelta(math.Pi, 22/7.0, 0.01) func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -651,7 +735,7 @@ func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, del // InDeltaf asserts that the two numerals are within delta of each other. // -// a.InDeltaf(math.Pi, 22/7.0, 0.01, "error message %s", "formatted") +// a.InDeltaf(math.Pi, 22/7.0, 0.01, "error message %s", "formatted") func (a *Assertions) InDeltaf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -693,9 +777,9 @@ func (a *Assertions) InEpsilonf(expected interface{}, actual interface{}, epsilo // IsDecreasing asserts that the collection is decreasing // -// a.IsDecreasing([]int{2, 1, 0}) -// a.IsDecreasing([]float{2, 1}) -// a.IsDecreasing([]string{"b", "a"}) +// a.IsDecreasing([]int{2, 1, 0}) +// a.IsDecreasing([]float{2, 1}) +// a.IsDecreasing([]string{"b", "a"}) func (a *Assertions) IsDecreasing(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -705,9 +789,9 @@ func (a *Assertions) IsDecreasing(object interface{}, msgAndArgs ...interface{}) // IsDecreasingf asserts that the collection is decreasing // -// a.IsDecreasingf([]int{2, 1, 0}, "error message %s", "formatted") -// 
a.IsDecreasingf([]float{2, 1}, "error message %s", "formatted") -// a.IsDecreasingf([]string{"b", "a"}, "error message %s", "formatted") +// a.IsDecreasingf([]int{2, 1, 0}, "error message %s", "formatted") +// a.IsDecreasingf([]float{2, 1}, "error message %s", "formatted") +// a.IsDecreasingf([]string{"b", "a"}, "error message %s", "formatted") func (a *Assertions) IsDecreasingf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -717,9 +801,9 @@ func (a *Assertions) IsDecreasingf(object interface{}, msg string, args ...inter // IsIncreasing asserts that the collection is increasing // -// a.IsIncreasing([]int{1, 2, 3}) -// a.IsIncreasing([]float{1, 2}) -// a.IsIncreasing([]string{"a", "b"}) +// a.IsIncreasing([]int{1, 2, 3}) +// a.IsIncreasing([]float{1, 2}) +// a.IsIncreasing([]string{"a", "b"}) func (a *Assertions) IsIncreasing(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -729,9 +813,9 @@ func (a *Assertions) IsIncreasing(object interface{}, msgAndArgs ...interface{}) // IsIncreasingf asserts that the collection is increasing // -// a.IsIncreasingf([]int{1, 2, 3}, "error message %s", "formatted") -// a.IsIncreasingf([]float{1, 2}, "error message %s", "formatted") -// a.IsIncreasingf([]string{"a", "b"}, "error message %s", "formatted") +// a.IsIncreasingf([]int{1, 2, 3}, "error message %s", "formatted") +// a.IsIncreasingf([]float{1, 2}, "error message %s", "formatted") +// a.IsIncreasingf([]string{"a", "b"}, "error message %s", "formatted") func (a *Assertions) IsIncreasingf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -741,9 +825,9 @@ func (a *Assertions) IsIncreasingf(object interface{}, msg string, args ...inter // IsNonDecreasing asserts that the collection is not decreasing // -// a.IsNonDecreasing([]int{1, 1, 2}) -// a.IsNonDecreasing([]float{1, 2}) -// a.IsNonDecreasing([]string{"a", 
"b"}) +// a.IsNonDecreasing([]int{1, 1, 2}) +// a.IsNonDecreasing([]float{1, 2}) +// a.IsNonDecreasing([]string{"a", "b"}) func (a *Assertions) IsNonDecreasing(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -753,9 +837,9 @@ func (a *Assertions) IsNonDecreasing(object interface{}, msgAndArgs ...interface // IsNonDecreasingf asserts that the collection is not decreasing // -// a.IsNonDecreasingf([]int{1, 1, 2}, "error message %s", "formatted") -// a.IsNonDecreasingf([]float{1, 2}, "error message %s", "formatted") -// a.IsNonDecreasingf([]string{"a", "b"}, "error message %s", "formatted") +// a.IsNonDecreasingf([]int{1, 1, 2}, "error message %s", "formatted") +// a.IsNonDecreasingf([]float{1, 2}, "error message %s", "formatted") +// a.IsNonDecreasingf([]string{"a", "b"}, "error message %s", "formatted") func (a *Assertions) IsNonDecreasingf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -765,9 +849,9 @@ func (a *Assertions) IsNonDecreasingf(object interface{}, msg string, args ...in // IsNonIncreasing asserts that the collection is not increasing // -// a.IsNonIncreasing([]int{2, 1, 1}) -// a.IsNonIncreasing([]float{2, 1}) -// a.IsNonIncreasing([]string{"b", "a"}) +// a.IsNonIncreasing([]int{2, 1, 1}) +// a.IsNonIncreasing([]float{2, 1}) +// a.IsNonIncreasing([]string{"b", "a"}) func (a *Assertions) IsNonIncreasing(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -777,9 +861,9 @@ func (a *Assertions) IsNonIncreasing(object interface{}, msgAndArgs ...interface // IsNonIncreasingf asserts that the collection is not increasing // -// a.IsNonIncreasingf([]int{2, 1, 1}, "error message %s", "formatted") -// a.IsNonIncreasingf([]float{2, 1}, "error message %s", "formatted") -// a.IsNonIncreasingf([]string{"b", "a"}, "error message %s", "formatted") +// a.IsNonIncreasingf([]int{2, 1, 1}, "error message %s", 
"formatted") +// a.IsNonIncreasingf([]float{2, 1}, "error message %s", "formatted") +// a.IsNonIncreasingf([]string{"b", "a"}, "error message %s", "formatted") func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -805,7 +889,7 @@ func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg s // JSONEq asserts that two JSON strings are equivalent. // -// a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +// a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -815,7 +899,7 @@ func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interf // JSONEqf asserts that two JSON strings are equivalent. // -// a.JSONEqf(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") +// a.JSONEqf(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -826,7 +910,7 @@ func (a *Assertions) JSONEqf(expected string, actual string, msg string, args .. // Len asserts that the specified object has specific length. // Len also fails if the object has a type that len() not accept. // -// a.Len(mySlice, 3) +// a.Len(mySlice, 3) func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -837,7 +921,7 @@ func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface // Lenf asserts that the specified object has specific length. // Lenf also fails if the object has a type that len() not accept. 
// -// a.Lenf(mySlice, 3, "error message %s", "formatted") +// a.Lenf(mySlice, 3, "error message %s", "formatted") func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -847,9 +931,9 @@ func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...in // Less asserts that the first element is less than the second // -// a.Less(1, 2) -// a.Less(float64(1), float64(2)) -// a.Less("a", "b") +// a.Less(1, 2) +// a.Less(float64(1), float64(2)) +// a.Less("a", "b") func (a *Assertions) Less(e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -859,10 +943,10 @@ func (a *Assertions) Less(e1 interface{}, e2 interface{}, msgAndArgs ...interfac // LessOrEqual asserts that the first element is less than or equal to the second // -// a.LessOrEqual(1, 2) -// a.LessOrEqual(2, 2) -// a.LessOrEqual("a", "b") -// a.LessOrEqual("b", "b") +// a.LessOrEqual(1, 2) +// a.LessOrEqual(2, 2) +// a.LessOrEqual("a", "b") +// a.LessOrEqual("b", "b") func (a *Assertions) LessOrEqual(e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -872,10 +956,10 @@ func (a *Assertions) LessOrEqual(e1 interface{}, e2 interface{}, msgAndArgs ...i // LessOrEqualf asserts that the first element is less than or equal to the second // -// a.LessOrEqualf(1, 2, "error message %s", "formatted") -// a.LessOrEqualf(2, 2, "error message %s", "formatted") -// a.LessOrEqualf("a", "b", "error message %s", "formatted") -// a.LessOrEqualf("b", "b", "error message %s", "formatted") +// a.LessOrEqualf(1, 2, "error message %s", "formatted") +// a.LessOrEqualf(2, 2, "error message %s", "formatted") +// a.LessOrEqualf("a", "b", "error message %s", "formatted") +// a.LessOrEqualf("b", "b", "error message %s", "formatted") func (a *Assertions) LessOrEqualf(e1 interface{}, e2 interface{}, msg string, args 
...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -885,9 +969,9 @@ func (a *Assertions) LessOrEqualf(e1 interface{}, e2 interface{}, msg string, ar // Lessf asserts that the first element is less than the second // -// a.Lessf(1, 2, "error message %s", "formatted") -// a.Lessf(float64(1), float64(2), "error message %s", "formatted") -// a.Lessf("a", "b", "error message %s", "formatted") +// a.Lessf(1, 2, "error message %s", "formatted") +// a.Lessf(float64(1), float64(2), "error message %s", "formatted") +// a.Lessf("a", "b", "error message %s", "formatted") func (a *Assertions) Lessf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -897,8 +981,8 @@ func (a *Assertions) Lessf(e1 interface{}, e2 interface{}, msg string, args ...i // Negative asserts that the specified element is negative // -// a.Negative(-1) -// a.Negative(-1.23) +// a.Negative(-1) +// a.Negative(-1.23) func (a *Assertions) Negative(e interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -908,8 +992,8 @@ func (a *Assertions) Negative(e interface{}, msgAndArgs ...interface{}) bool { // Negativef asserts that the specified element is negative // -// a.Negativef(-1, "error message %s", "formatted") -// a.Negativef(-1.23, "error message %s", "formatted") +// a.Negativef(-1, "error message %s", "formatted") +// a.Negativef(-1.23, "error message %s", "formatted") func (a *Assertions) Negativef(e interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -920,7 +1004,7 @@ func (a *Assertions) Negativef(e interface{}, msg string, args ...interface{}) b // Never asserts that the given condition doesn't satisfy in waitFor time, // periodically checking the target function each tick. 
// -// a.Never(func() bool { return false; }, time.Second, 10*time.Millisecond) +// a.Never(func() bool { return false; }, time.Second, 10*time.Millisecond) func (a *Assertions) Never(condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -931,7 +1015,7 @@ func (a *Assertions) Never(condition func() bool, waitFor time.Duration, tick ti // Neverf asserts that the given condition doesn't satisfy in waitFor time, // periodically checking the target function each tick. // -// a.Neverf(func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +// a.Neverf(func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") func (a *Assertions) Neverf(condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -941,7 +1025,7 @@ func (a *Assertions) Neverf(condition func() bool, waitFor time.Duration, tick t // Nil asserts that the specified object is nil. // -// a.Nil(err) +// a.Nil(err) func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -951,7 +1035,7 @@ func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) bool { // Nilf asserts that the specified object is nil. // -// a.Nilf(err, "error message %s", "formatted") +// a.Nilf(err, "error message %s", "formatted") func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -979,10 +1063,10 @@ func (a *Assertions) NoDirExistsf(path string, msg string, args ...interface{}) // NoError asserts that a function returned no error (i.e. `nil`). 
// -// actualObj, err := SomeFunction() -// if a.NoError(err) { -// assert.Equal(t, expectedObj, actualObj) -// } +// actualObj, err := SomeFunction() +// if a.NoError(err) { +// assert.Equal(t, expectedObj, actualObj) +// } func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -992,10 +1076,10 @@ func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) bool { // NoErrorf asserts that a function returned no error (i.e. `nil`). // -// actualObj, err := SomeFunction() -// if a.NoErrorf(err, "error message %s", "formatted") { -// assert.Equal(t, expectedObj, actualObj) -// } +// actualObj, err := SomeFunction() +// if a.NoErrorf(err, "error message %s", "formatted") { +// assert.Equal(t, expectedObj, actualObj) +// } func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1024,9 +1108,9 @@ func (a *Assertions) NoFileExistsf(path string, msg string, args ...interface{}) // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. // -// a.NotContains("Hello World", "Earth") -// a.NotContains(["Hello", "World"], "Earth") -// a.NotContains({"Hello": "World"}, "Earth") +// a.NotContains("Hello World", "Earth") +// a.NotContains(["Hello", "World"], "Earth") +// a.NotContains({"Hello": "World"}, "Earth") func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1037,9 +1121,9 @@ func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs // NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. 
// -// a.NotContainsf("Hello World", "Earth", "error message %s", "formatted") -// a.NotContainsf(["Hello", "World"], "Earth", "error message %s", "formatted") -// a.NotContainsf({"Hello": "World"}, "Earth", "error message %s", "formatted") +// a.NotContainsf("Hello World", "Earth", "error message %s", "formatted") +// a.NotContainsf(["Hello", "World"], "Earth", "error message %s", "formatted") +// a.NotContainsf({"Hello": "World"}, "Earth", "error message %s", "formatted") func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1050,9 +1134,9 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin // NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // -// if a.NotEmpty(obj) { -// assert.Equal(t, "two", obj[1]) -// } +// if a.NotEmpty(obj) { +// assert.Equal(t, "two", obj[1]) +// } func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1063,9 +1147,9 @@ func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) boo // NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // -// if a.NotEmptyf(obj, "error message %s", "formatted") { -// assert.Equal(t, "two", obj[1]) -// } +// if a.NotEmptyf(obj, "error message %s", "formatted") { +// assert.Equal(t, "two", obj[1]) +// } func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1075,7 +1159,7 @@ func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface // NotEqual asserts that the specified values are NOT equal. 
// -// a.NotEqual(obj1, obj2) +// a.NotEqual(obj1, obj2) // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). @@ -1088,7 +1172,7 @@ func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndAr // NotEqualValues asserts that two objects are not equal even when converted to the same type // -// a.NotEqualValues(obj1, obj2) +// a.NotEqualValues(obj1, obj2) func (a *Assertions) NotEqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1098,7 +1182,7 @@ func (a *Assertions) NotEqualValues(expected interface{}, actual interface{}, ms // NotEqualValuesf asserts that two objects are not equal even when converted to the same type // -// a.NotEqualValuesf(obj1, obj2, "error message %s", "formatted") +// a.NotEqualValuesf(obj1, obj2, "error message %s", "formatted") func (a *Assertions) NotEqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1108,7 +1192,7 @@ func (a *Assertions) NotEqualValuesf(expected interface{}, actual interface{}, m // NotEqualf asserts that the specified values are NOT equal. // -// a.NotEqualf(obj1, obj2, "error message %s", "formatted") +// a.NotEqualf(obj1, obj2, "error message %s", "formatted") // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). @@ -1139,7 +1223,7 @@ func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...in // NotNil asserts that the specified object is not nil. 
// -// a.NotNil(err) +// a.NotNil(err) func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1149,7 +1233,7 @@ func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) bool // NotNilf asserts that the specified object is not nil. // -// a.NotNilf(err, "error message %s", "formatted") +// a.NotNilf(err, "error message %s", "formatted") func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1159,7 +1243,7 @@ func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{} // NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. // -// a.NotPanics(func(){ RemainCalm() }) +// a.NotPanics(func(){ RemainCalm() }) func (a *Assertions) NotPanics(f PanicTestFunc, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1169,7 +1253,7 @@ func (a *Assertions) NotPanics(f PanicTestFunc, msgAndArgs ...interface{}) bool // NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic. // -// a.NotPanicsf(func(){ RemainCalm() }, "error message %s", "formatted") +// a.NotPanicsf(func(){ RemainCalm() }, "error message %s", "formatted") func (a *Assertions) NotPanicsf(f PanicTestFunc, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1179,8 +1263,8 @@ func (a *Assertions) NotPanicsf(f PanicTestFunc, msg string, args ...interface{} // NotRegexp asserts that a specified regexp does not match a string. 
// -// a.NotRegexp(regexp.MustCompile("starts"), "it's starting") -// a.NotRegexp("^start", "it's not starting") +// a.NotRegexp(regexp.MustCompile("starts"), "it's starting") +// a.NotRegexp("^start", "it's not starting") func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1190,8 +1274,8 @@ func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...in // NotRegexpf asserts that a specified regexp does not match a string. // -// a.NotRegexpf(regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") -// a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted") +// a.NotRegexpf(regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") +// a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted") func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1201,7 +1285,7 @@ func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, arg // NotSame asserts that two pointers do not reference the same object. // -// a.NotSame(ptr1, ptr2) +// a.NotSame(ptr1, ptr2) // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1214,7 +1298,7 @@ func (a *Assertions) NotSame(expected interface{}, actual interface{}, msgAndArg // NotSamef asserts that two pointers do not reference the same object. // -// a.NotSamef(ptr1, ptr2, "error message %s", "formatted") +// a.NotSamef(ptr1, ptr2, "error message %s", "formatted") // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. 
@@ -1228,7 +1312,7 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri // NotSubset asserts that the specified list(array, slice...) contains not all // elements given in the specified subset(array, slice...). // -// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") +// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1239,7 +1323,7 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs // NotSubsetf asserts that the specified list(array, slice...) contains not all // elements given in the specified subset(array, slice...). // -// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") +// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1265,7 +1349,7 @@ func (a *Assertions) NotZerof(i interface{}, msg string, args ...interface{}) bo // Panics asserts that the code inside the specified PanicTestFunc panics. // -// a.Panics(func(){ GoCrazy() }) +// a.Panics(func(){ GoCrazy() }) func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1277,7 +1361,7 @@ func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool { // panics, and that the recovered panic value is an error that satisfies the // EqualError comparison. 
// -// a.PanicsWithError("crazy error", func(){ GoCrazy() }) +// a.PanicsWithError("crazy error", func(){ GoCrazy() }) func (a *Assertions) PanicsWithError(errString string, f PanicTestFunc, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1289,7 +1373,7 @@ func (a *Assertions) PanicsWithError(errString string, f PanicTestFunc, msgAndAr // panics, and that the recovered panic value is an error that satisfies the // EqualError comparison. // -// a.PanicsWithErrorf("crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +// a.PanicsWithErrorf("crazy error", func(){ GoCrazy() }, "error message %s", "formatted") func (a *Assertions) PanicsWithErrorf(errString string, f PanicTestFunc, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1300,7 +1384,7 @@ func (a *Assertions) PanicsWithErrorf(errString string, f PanicTestFunc, msg str // PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // -// a.PanicsWithValue("crazy error", func(){ GoCrazy() }) +// a.PanicsWithValue("crazy error", func(){ GoCrazy() }) func (a *Assertions) PanicsWithValue(expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1311,7 +1395,7 @@ func (a *Assertions) PanicsWithValue(expected interface{}, f PanicTestFunc, msgA // PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. 
// -// a.PanicsWithValuef("crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +// a.PanicsWithValuef("crazy error", func(){ GoCrazy() }, "error message %s", "formatted") func (a *Assertions) PanicsWithValuef(expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1321,7 +1405,7 @@ func (a *Assertions) PanicsWithValuef(expected interface{}, f PanicTestFunc, msg // Panicsf asserts that the code inside the specified PanicTestFunc panics. // -// a.Panicsf(func(){ GoCrazy() }, "error message %s", "formatted") +// a.Panicsf(func(){ GoCrazy() }, "error message %s", "formatted") func (a *Assertions) Panicsf(f PanicTestFunc, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1331,8 +1415,8 @@ func (a *Assertions) Panicsf(f PanicTestFunc, msg string, args ...interface{}) b // Positive asserts that the specified element is positive // -// a.Positive(1) -// a.Positive(1.23) +// a.Positive(1) +// a.Positive(1.23) func (a *Assertions) Positive(e interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1342,8 +1426,8 @@ func (a *Assertions) Positive(e interface{}, msgAndArgs ...interface{}) bool { // Positivef asserts that the specified element is positive // -// a.Positivef(1, "error message %s", "formatted") -// a.Positivef(1.23, "error message %s", "formatted") +// a.Positivef(1, "error message %s", "formatted") +// a.Positivef(1.23, "error message %s", "formatted") func (a *Assertions) Positivef(e interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1353,8 +1437,8 @@ func (a *Assertions) Positivef(e interface{}, msg string, args ...interface{}) b // Regexp asserts that a specified regexp matches a string. 
// -// a.Regexp(regexp.MustCompile("start"), "it's starting") -// a.Regexp("start...$", "it's not starting") +// a.Regexp(regexp.MustCompile("start"), "it's starting") +// a.Regexp("start...$", "it's not starting") func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1364,8 +1448,8 @@ func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...inter // Regexpf asserts that a specified regexp matches a string. // -// a.Regexpf(regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") -// a.Regexpf("start...$", "it's not starting", "error message %s", "formatted") +// a.Regexpf(regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") +// a.Regexpf("start...$", "it's not starting", "error message %s", "formatted") func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1375,7 +1459,7 @@ func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args . // Same asserts that two pointers reference the same object. // -// a.Same(ptr1, ptr2) +// a.Same(ptr1, ptr2) // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1388,7 +1472,7 @@ func (a *Assertions) Same(expected interface{}, actual interface{}, msgAndArgs . // Samef asserts that two pointers reference the same object. // -// a.Samef(ptr1, ptr2, "error message %s", "formatted") +// a.Samef(ptr1, ptr2, "error message %s", "formatted") // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1402,7 +1486,7 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string, // Subset asserts that the specified list(array, slice...) 
contains all // elements given in the specified subset(array, slice...). // -// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") +// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1413,7 +1497,7 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ... // Subsetf asserts that the specified list(array, slice...) contains all // elements given in the specified subset(array, slice...). // -// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") +// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1423,7 +1507,7 @@ func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, a // True asserts that the specified value is true. // -// a.True(myBool) +// a.True(myBool) func (a *Assertions) True(value bool, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1433,7 +1517,7 @@ func (a *Assertions) True(value bool, msgAndArgs ...interface{}) bool { // Truef asserts that the specified value is true. // -// a.Truef(myBool, "error message %s", "formatted") +// a.Truef(myBool, "error message %s", "formatted") func (a *Assertions) Truef(value bool, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1443,7 +1527,7 @@ func (a *Assertions) Truef(value bool, msg string, args ...interface{}) bool { // WithinDuration asserts that the two times are within duration delta of each other. 
// -// a.WithinDuration(time.Now(), time.Now(), 10*time.Second) +// a.WithinDuration(time.Now(), time.Now(), 10*time.Second) func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1453,7 +1537,7 @@ func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta // WithinDurationf asserts that the two times are within duration delta of each other. // -// a.WithinDurationf(time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") +// a.WithinDurationf(time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1463,7 +1547,7 @@ func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta // WithinRange asserts that a time is within a time range (inclusive). // -// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) +// a.WithinRange(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1473,7 +1557,7 @@ func (a *Assertions) WithinRange(actual time.Time, start time.Time, end time.Tim // WithinRangef asserts that a time is within a time range (inclusive). 
// -// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") +// a.WithinRangef(time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") func (a *Assertions) WithinRangef(actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() diff --git a/vendor/github.com/stretchr/testify/assert/assertion_order.go b/vendor/github.com/stretchr/testify/assert/assertion_order.go index 7594487835..00df62a059 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_order.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_order.go @@ -46,36 +46,36 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT // IsIncreasing asserts that the collection is increasing // -// assert.IsIncreasing(t, []int{1, 2, 3}) -// assert.IsIncreasing(t, []float{1, 2}) -// assert.IsIncreasing(t, []string{"a", "b"}) +// assert.IsIncreasing(t, []int{1, 2, 3}) +// assert.IsIncreasing(t, []float{1, 2}) +// assert.IsIncreasing(t, []string{"a", "b"}) func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) } // IsNonIncreasing asserts that the collection is not increasing // -// assert.IsNonIncreasing(t, []int{2, 1, 1}) -// assert.IsNonIncreasing(t, []float{2, 1}) -// assert.IsNonIncreasing(t, []string{"b", "a"}) +// assert.IsNonIncreasing(t, []int{2, 1, 1}) +// assert.IsNonIncreasing(t, []float{2, 1}) +// assert.IsNonIncreasing(t, []string{"b", "a"}) func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) 
} // IsDecreasing asserts that the collection is decreasing // -// assert.IsDecreasing(t, []int{2, 1, 0}) -// assert.IsDecreasing(t, []float{2, 1}) -// assert.IsDecreasing(t, []string{"b", "a"}) +// assert.IsDecreasing(t, []int{2, 1, 0}) +// assert.IsDecreasing(t, []float{2, 1}) +// assert.IsDecreasing(t, []string{"b", "a"}) func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) } // IsNonDecreasing asserts that the collection is not decreasing // -// assert.IsNonDecreasing(t, []int{1, 1, 2}) -// assert.IsNonDecreasing(t, []float{1, 2}) -// assert.IsNonDecreasing(t, []string{"a", "b"}) +// assert.IsNonDecreasing(t, []int{1, 1, 2}) +// assert.IsNonDecreasing(t, []float{1, 2}) +// assert.IsNonDecreasing(t, []string{"a", "b"}) func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) } diff --git a/vendor/github.com/stretchr/testify/assert/assertions.go b/vendor/github.com/stretchr/testify/assert/assertions.go index 2924cf3a14..a55d1bba92 100644 --- a/vendor/github.com/stretchr/testify/assert/assertions.go +++ b/vendor/github.com/stretchr/testify/assert/assertions.go @@ -75,6 +75,77 @@ func ObjectsAreEqual(expected, actual interface{}) bool { return bytes.Equal(exp, act) } +// copyExportedFields iterates downward through nested data structures and creates a copy +// that only contains the exported struct fields. 
+func copyExportedFields(expected interface{}) interface{} { + if isNil(expected) { + return expected + } + + expectedType := reflect.TypeOf(expected) + expectedKind := expectedType.Kind() + expectedValue := reflect.ValueOf(expected) + + switch expectedKind { + case reflect.Struct: + result := reflect.New(expectedType).Elem() + for i := 0; i < expectedType.NumField(); i++ { + field := expectedType.Field(i) + isExported := field.IsExported() + if isExported { + fieldValue := expectedValue.Field(i) + if isNil(fieldValue) || isNil(fieldValue.Interface()) { + continue + } + newValue := copyExportedFields(fieldValue.Interface()) + result.Field(i).Set(reflect.ValueOf(newValue)) + } + } + return result.Interface() + + case reflect.Ptr: + result := reflect.New(expectedType.Elem()) + unexportedRemoved := copyExportedFields(expectedValue.Elem().Interface()) + result.Elem().Set(reflect.ValueOf(unexportedRemoved)) + return result.Interface() + + case reflect.Array, reflect.Slice: + result := reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) + for i := 0; i < expectedValue.Len(); i++ { + index := expectedValue.Index(i) + if isNil(index) { + continue + } + unexportedRemoved := copyExportedFields(index.Interface()) + result.Index(i).Set(reflect.ValueOf(unexportedRemoved)) + } + return result.Interface() + + case reflect.Map: + result := reflect.MakeMap(expectedType) + for _, k := range expectedValue.MapKeys() { + index := expectedValue.MapIndex(k) + unexportedRemoved := copyExportedFields(index.Interface()) + result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved)) + } + return result.Interface() + + default: + return expected + } +} + +// ObjectsExportedFieldsAreEqual determines if the exported (public) fields of two objects are +// considered equal. This comparison of only exported fields is applied recursively to nested data +// structures. +// +// This function does no assertion of any kind. 
+func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool { + expectedCleaned := copyExportedFields(expected) + actualCleaned := copyExportedFields(actual) + return ObjectsAreEqualValues(expectedCleaned, actualCleaned) +} + // ObjectsAreEqualValues gets whether two objects are equal, or if their // values are equal. func ObjectsAreEqualValues(expected, actual interface{}) bool { @@ -271,7 +342,7 @@ type labeledContent struct { // labeledOutput returns a string consisting of the provided labeledContent. Each labeled output is appended in the following manner: // -// \t{{label}}:{{align_spaces}}\t{{content}}\n +// \t{{label}}:{{align_spaces}}\t{{content}}\n // // The initial carriage return is required to undo/erase any padding added by testing.T.Errorf. The "\t{{label}}:" is for the label. // If a label is shorter than the longest label provided, padding spaces are added to make all the labels match in length. Once this @@ -294,7 +365,7 @@ func labeledOutput(content ...labeledContent) string { // Implements asserts that an object is implemented by the specified interface. // -// assert.Implements(t, (*MyInterface)(nil), new(MyObject)) +// assert.Implements(t, (*MyInterface)(nil), new(MyObject)) func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -326,7 +397,7 @@ func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs // Equal asserts that two objects are equal. // -// assert.Equal(t, 123, 123) +// assert.Equal(t, 123, 123) // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality @@ -367,7 +438,7 @@ func validateEqualArgs(expected, actual interface{}) error { // Same asserts that two pointers reference the same object. // -// assert.Same(t, ptr1, ptr2) +// assert.Same(t, ptr1, ptr2) // // Both arguments must be pointer variables. 
Pointer variable sameness is // determined based on the equality of both type and value. @@ -387,7 +458,7 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b // NotSame asserts that two pointers do not reference the same object. // -// assert.NotSame(t, ptr1, ptr2) +// assert.NotSame(t, ptr1, ptr2) // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -455,7 +526,7 @@ func truncatingFormat(data interface{}) string { // EqualValues asserts that two objects are equal or convertable to the same types // and equal. // -// assert.EqualValues(t, uint32(123), int32(123)) +// assert.EqualValues(t, uint32(123), int32(123)) func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -473,9 +544,53 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa } +// EqualExportedValues asserts that the types of two objects are equal and their public +// fields are also equal. This is useful for comparing structs that have private fields +// that could potentially differ. +// +// type S struct { +// Exported int +// notExported int +// } +// assert.EqualExportedValues(t, S{1, 2}, S{1, 3}) => true +// assert.EqualExportedValues(t, S{1, 2}, S{2, 3}) => false +func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + aType := reflect.TypeOf(expected) + bType := reflect.TypeOf(actual) + + if aType != bType { + return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) + } + + if aType.Kind() != reflect.Struct { + return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...) 
+ } + + if bType.Kind() != reflect.Struct { + return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...) + } + + expected = copyExportedFields(expected) + actual = copyExportedFields(actual) + + if !ObjectsAreEqualValues(expected, actual) { + diff := diff(expected, actual) + expected, actual = formatUnequalValues(expected, actual) + return Fail(t, fmt.Sprintf("Not equal (comparing only exported fields): \n"+ + "expected: %s\n"+ + "actual : %s%s", expected, actual, diff), msgAndArgs...) + } + + return true +} + // Exactly asserts that two objects are equal in value and type. // -// assert.Exactly(t, int32(123), int64(123)) +// assert.Exactly(t, int32(123), int64(123)) func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -494,7 +609,7 @@ func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{} // NotNil asserts that the specified object is not nil. // -// assert.NotNil(t, err) +// assert.NotNil(t, err) func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { if !isNil(object) { return true @@ -540,7 +655,7 @@ func isNil(object interface{}) bool { // Nil asserts that the specified object is nil. // -// assert.Nil(t, err) +// assert.Nil(t, err) func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { if isNil(object) { return true @@ -583,7 +698,7 @@ func isEmpty(object interface{}) bool { // Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // -// assert.Empty(t, obj) +// assert.Empty(t, obj) func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { pass := isEmpty(object) if !pass { @@ -600,9 +715,9 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { // NotEmpty asserts that the specified object is NOT empty. I.e. 
not nil, "", false, 0 or either // a slice or a channel with len == 0. // -// if assert.NotEmpty(t, obj) { -// assert.Equal(t, "two", obj[1]) -// } +// if assert.NotEmpty(t, obj) { +// assert.Equal(t, "two", obj[1]) +// } func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { pass := !isEmpty(object) if !pass { @@ -631,7 +746,7 @@ func getLen(x interface{}) (ok bool, length int) { // Len asserts that the specified object has specific length. // Len also fails if the object has a type that len() not accept. // -// assert.Len(t, mySlice, 3) +// assert.Len(t, mySlice, 3) func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -649,7 +764,7 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) // True asserts that the specified value is true. // -// assert.True(t, myBool) +// assert.True(t, myBool) func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { if !value { if h, ok := t.(tHelper); ok { @@ -664,7 +779,7 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { // False asserts that the specified value is false. // -// assert.False(t, myBool) +// assert.False(t, myBool) func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { if value { if h, ok := t.(tHelper); ok { @@ -679,7 +794,7 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { // NotEqual asserts that the specified values are NOT equal. // -// assert.NotEqual(t, obj1, obj2) +// assert.NotEqual(t, obj1, obj2) // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). 
@@ -702,7 +817,7 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{ // NotEqualValues asserts that two objects are not equal even when converted to the same type // -// assert.NotEqualValues(t, obj1, obj2) +// assert.NotEqualValues(t, obj1, obj2) func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -761,9 +876,9 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) { // Contains asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // -// assert.Contains(t, "Hello World", "World") -// assert.Contains(t, ["Hello", "World"], "World") -// assert.Contains(t, {"Hello": "World"}, "Hello") +// assert.Contains(t, "Hello World", "World") +// assert.Contains(t, ["Hello", "World"], "World") +// assert.Contains(t, {"Hello": "World"}, "Hello") func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -784,9 +899,9 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. 
// -// assert.NotContains(t, "Hello World", "Earth") -// assert.NotContains(t, ["Hello", "World"], "Earth") -// assert.NotContains(t, {"Hello": "World"}, "Earth") +// assert.NotContains(t, "Hello World", "Earth") +// assert.NotContains(t, ["Hello", "World"], "Earth") +// assert.NotContains(t, {"Hello": "World"}, "Earth") func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -794,10 +909,10 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) ok, found := containsElement(s, contains) if !ok { - return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...) + return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...) } if found { - return Fail(t, fmt.Sprintf("\"%s\" should not contain \"%s\"", s, contains), msgAndArgs...) + return Fail(t, fmt.Sprintf("%#v should not contain %#v", s, contains), msgAndArgs...) } return true @@ -807,7 +922,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) // Subset asserts that the specified list(array, slice...) contains all // elements given in the specified subset(array, slice...). // -// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") +// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { if h, ok := t.(tHelper); ok { h.Helper() @@ -863,7 +978,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok // NotSubset asserts that the specified list(array, slice...) contains not all // elements given in the specified subset(array, slice...). 
// -// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") +// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1048,7 +1163,7 @@ func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string // Panics asserts that the code inside the specified PanicTestFunc panics. // -// assert.Panics(t, func(){ GoCrazy() }) +// assert.Panics(t, func(){ GoCrazy() }) func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1064,7 +1179,7 @@ func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { // PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // -// assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) +// assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1085,7 +1200,7 @@ func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndAr // panics, and that the recovered panic value is an error that satisfies the // EqualError comparison. // -// assert.PanicsWithError(t, "crazy error", func(){ GoCrazy() }) +// assert.PanicsWithError(t, "crazy error", func(){ GoCrazy() }) func PanicsWithError(t TestingT, errString string, f PanicTestFunc, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1105,7 +1220,7 @@ func PanicsWithError(t TestingT, errString string, f PanicTestFunc, msgAndArgs . // NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. 
// -// assert.NotPanics(t, func(){ RemainCalm() }) +// assert.NotPanics(t, func(){ RemainCalm() }) func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1120,7 +1235,7 @@ func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { // WithinDuration asserts that the two times are within duration delta of each other. // -// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) +// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1136,7 +1251,7 @@ func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, // WithinRange asserts that a time is within a time range (inclusive). // -// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) +// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) func WithinRange(t TestingT, actual, start, end time.Time, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1195,7 +1310,7 @@ func toFloat(x interface{}) (float64, bool) { // InDelta asserts that the two numerals are within delta of each other. // -// assert.InDelta(t, math.Pi, 22/7.0, 0.01) +// assert.InDelta(t, math.Pi, 22/7.0, 0.01) func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1368,10 +1483,10 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m // NoError asserts that a function returned no error (i.e. `nil`). 
// -// actualObj, err := SomeFunction() -// if assert.NoError(t, err) { -// assert.Equal(t, expectedObj, actualObj) -// } +// actualObj, err := SomeFunction() +// if assert.NoError(t, err) { +// assert.Equal(t, expectedObj, actualObj) +// } func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool { if err != nil { if h, ok := t.(tHelper); ok { @@ -1385,10 +1500,10 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool { // Error asserts that a function returned an error (i.e. not `nil`). // -// actualObj, err := SomeFunction() -// if assert.Error(t, err) { -// assert.Equal(t, expectedError, err) -// } +// actualObj, err := SomeFunction() +// if assert.Error(t, err) { +// assert.Equal(t, expectedError, err) +// } func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { if err == nil { if h, ok := t.(tHelper); ok { @@ -1403,8 +1518,8 @@ func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { // EqualError asserts that a function returned an error (i.e. not `nil`) // and that it is equal to the provided error. // -// actualObj, err := SomeFunction() -// assert.EqualError(t, err, expectedErrorString) +// actualObj, err := SomeFunction() +// assert.EqualError(t, err, expectedErrorString) func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1426,8 +1541,8 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte // ErrorContains asserts that a function returned an error (i.e. not `nil`) // and that the error contains the specified substring. 
// -// actualObj, err := SomeFunction() -// assert.ErrorContains(t, err, expectedErrorSubString) +// actualObj, err := SomeFunction() +// assert.ErrorContains(t, err, expectedErrorSubString) func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1460,8 +1575,8 @@ func matchRegexp(rx interface{}, str interface{}) bool { // Regexp asserts that a specified regexp matches a string. // -// assert.Regexp(t, regexp.MustCompile("start"), "it's starting") -// assert.Regexp(t, "start...$", "it's not starting") +// assert.Regexp(t, regexp.MustCompile("start"), "it's starting") +// assert.Regexp(t, "start...$", "it's not starting") func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1478,8 +1593,8 @@ func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface // NotRegexp asserts that a specified regexp does not match a string. // -// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") -// assert.NotRegexp(t, "^start", "it's not starting") +// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") +// assert.NotRegexp(t, "^start", "it's not starting") func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1591,7 +1706,7 @@ func NoDirExists(t TestingT, path string, msgAndArgs ...interface{}) bool { // JSONEq asserts that two JSON strings are equivalent. 
// -// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1714,7 +1829,7 @@ type tHelper interface { // Eventually asserts that given condition will be met in waitFor time, // periodically checking target function each tick. // -// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) +// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -1744,10 +1859,93 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t } } +// CollectT implements the TestingT interface and collects all errors. +type CollectT struct { + errors []error +} + +// Errorf collects the error. +func (c *CollectT) Errorf(format string, args ...interface{}) { + c.errors = append(c.errors, fmt.Errorf(format, args...)) +} + +// FailNow panics. +func (c *CollectT) FailNow() { + panic("Assertion failed") +} + +// Reset clears the collected errors. +func (c *CollectT) Reset() { + c.errors = nil +} + +// Copy copies the collected errors to the supplied t. +func (c *CollectT) Copy(t TestingT) { + if tt, ok := t.(tHelper); ok { + tt.Helper() + } + for _, err := range c.errors { + t.Errorf("%v", err) + } +} + +// EventuallyWithT asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. In contrast to Eventually, +// it supplies a CollectT to the condition function, so that the condition +// function can use the CollectT to call other assertions. +// The condition is considered "met" if no errors are raised in a tick. 
+// The supplied CollectT collects all errors from one tick (if there are any). +// If the condition is not met before waitFor, the collected errors of +// the last tick are copied to t. +// +// externalValue := false +// go func() { +// time.Sleep(8*time.Second) +// externalValue = true +// }() +// assert.EventuallyWithT(t, func(c *assert.CollectT) { +// // add assertions as needed; any assertion failure will fail the current tick +// assert.True(c, externalValue, "expected 'externalValue' to be true") +// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + collect := new(CollectT) + ch := make(chan bool, 1) + + timer := time.NewTimer(waitFor) + defer timer.Stop() + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for tick := ticker.C; ; { + select { + case <-timer.C: + collect.Copy(t) + return Fail(t, "Condition never satisfied", msgAndArgs...) + case <-tick: + tick = nil + collect.Reset() + go func() { + condition(collect) + ch <- len(collect.errors) == 0 + }() + case v := <-ch: + if v { + return true + } + tick = ticker.C + } + } +} + // Never asserts that the given condition doesn't satisfy in waitFor time, // periodically checking the target function each tick. 
// -// assert.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond) +// assert.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond) func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() diff --git a/vendor/github.com/stretchr/testify/assert/doc.go b/vendor/github.com/stretchr/testify/assert/doc.go index c9dccc4d6c..4953981d38 100644 --- a/vendor/github.com/stretchr/testify/assert/doc.go +++ b/vendor/github.com/stretchr/testify/assert/doc.go @@ -1,39 +1,40 @@ // Package assert provides a set of comprehensive testing tools for use with the normal Go testing system. // -// Example Usage +// # Example Usage // // The following is a complete example using assert in a standard test function: -// import ( -// "testing" -// "github.com/stretchr/testify/assert" -// ) // -// func TestSomething(t *testing.T) { +// import ( +// "testing" +// "github.com/stretchr/testify/assert" +// ) // -// var a string = "Hello" -// var b string = "Hello" +// func TestSomething(t *testing.T) { // -// assert.Equal(t, a, b, "The two words should be the same.") +// var a string = "Hello" +// var b string = "Hello" // -// } +// assert.Equal(t, a, b, "The two words should be the same.") +// +// } // // if you assert many times, use the format below: // -// import ( -// "testing" -// "github.com/stretchr/testify/assert" -// ) +// import ( +// "testing" +// "github.com/stretchr/testify/assert" +// ) // -// func TestSomething(t *testing.T) { -// assert := assert.New(t) +// func TestSomething(t *testing.T) { +// assert := assert.New(t) // -// var a string = "Hello" -// var b string = "Hello" +// var a string = "Hello" +// var b string = "Hello" // -// assert.Equal(a, b, "The two words should be the same.") -// } +// assert.Equal(a, b, "The two words should be the same.") +// } // -// Assertions +// # Assertions // // Assertions allow you to easily write 
test code, and are global funcs in the `assert` package. // All assertion functions take, as the first argument, the `*testing.T` object provided by the diff --git a/vendor/github.com/stretchr/testify/assert/http_assertions.go b/vendor/github.com/stretchr/testify/assert/http_assertions.go index 4ed341dd28..d8038c28a7 100644 --- a/vendor/github.com/stretchr/testify/assert/http_assertions.go +++ b/vendor/github.com/stretchr/testify/assert/http_assertions.go @@ -23,7 +23,7 @@ func httpCode(handler http.HandlerFunc, method, url string, values url.Values) ( // HTTPSuccess asserts that a specified handler returns a success status code. // -// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) +// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) // // Returns whether the assertion was successful (true) or not (false). func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { @@ -45,7 +45,7 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, value // HTTPRedirect asserts that a specified handler returns a redirect status code. // -// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { @@ -67,7 +67,7 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, valu // HTTPError asserts that a specified handler returns an error status code. // -// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). 
func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { @@ -89,7 +89,7 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values // HTTPStatusCode asserts that a specified handler returns a specified status code. // -// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501) +// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501) // // Returns whether the assertion was successful (true) or not (false). func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) bool { @@ -124,7 +124,7 @@ func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) s // HTTPBodyContains asserts that a specified handler returns a // body that contains a string. // -// assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { @@ -144,7 +144,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, // HTTPBodyNotContains asserts that a specified handler returns a // body that does not contain a string. // -// assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). 
func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { diff --git a/vendor/github.com/tailscale/golang-x-crypto/acme/acme.go b/vendor/github.com/tailscale/golang-x-crypto/acme/acme.go new file mode 100644 index 0000000000..b2fd47deb6 --- /dev/null +++ b/vendor/github.com/tailscale/golang-x-crypto/acme/acme.go @@ -0,0 +1,900 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package acme provides an implementation of the +// Automatic Certificate Management Environment (ACME) spec, +// most famously used by Let's Encrypt. +// +// The initial implementation of this package was based on an early version +// of the spec. The current implementation supports only the modern +// RFC 8555 but some of the old API surface remains for compatibility. +// While code using the old API will still compile, it will return an error. +// Note the deprecation comments to update your code. +// +// See https://tools.ietf.org/html/rfc8555 for the spec. +// +// Most common scenarios will want to use autocert subdirectory instead, +// which provides automatic access to certificates from Let's Encrypt +// and any other ACME-based CA. +package acme + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "math/big" + "net/http" + "strings" + "sync" + "time" +) + +const ( + // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA. + LetsEncryptURL = "https://acme-v02.api.letsencrypt.org/directory" + + // ALPNProto is the ALPN protocol name used by a CA server when validating + // tls-alpn-01 challenges. 
+ // + // Package users must ensure their servers can negotiate the ACME ALPN in + // order for tls-alpn-01 challenge verifications to succeed. + // See the crypto/tls package's Config.NextProtos field. + ALPNProto = "acme-tls/1" +) + +// idPeACMEIdentifier is the OID for the ACME extension for the TLS-ALPN challenge. +// https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-05#section-5.1 +var idPeACMEIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} + +const ( + maxChainLen = 5 // max depth and breadth of a certificate chain + maxCertSize = 1 << 20 // max size of a certificate, in DER bytes + // Used for decoding certs from application/pem-certificate-chain response, + // the default when in RFC mode. + maxCertChainSize = maxCertSize * maxChainLen + + // Max number of collected nonces kept in memory. + // Expect usual peak of 1 or 2. + maxNonces = 100 +) + +// Client is an ACME client. +// +// The only required field is Key. An example of creating a client with a new key +// is as follows: +// +// key, err := rsa.GenerateKey(rand.Reader, 2048) +// if err != nil { +// log.Fatal(err) +// } +// client := &Client{Key: key} +type Client struct { + // Key is the account key used to register with a CA and sign requests. + // Key.Public() must return a *rsa.PublicKey or *ecdsa.PublicKey. + // + // The following algorithms are supported: + // RS256, ES256, ES384 and ES512. + // See RFC 7518 for more details about the algorithms. + Key crypto.Signer + + // HTTPClient optionally specifies an HTTP client to use + // instead of http.DefaultClient. + HTTPClient *http.Client + + // DirectoryURL points to the CA directory endpoint. + // If empty, LetsEncryptURL is used. + // Mutating this value after a successful call of Client's Discover method + // will have no effect. + DirectoryURL string + + // RetryBackoff computes the duration after which the nth retry of a failed request + // should occur. The value of n for the first call on failure is 1. 
+ // The values of r and resp are the request and response of the last failed attempt. + // If the returned value is negative or zero, no more retries are done and an error + // is returned to the caller of the original method. + // + // Requests which result in a 4xx client error are not retried, + // except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests. + // + // If RetryBackoff is nil, a truncated exponential backoff algorithm + // with the ceiling of 10 seconds is used, where each subsequent retry n + // is done after either ("Retry-After" + jitter) or (2^n seconds + jitter), + // preferring the former if "Retry-After" header is found in the resp. + // The jitter is a random value up to 1 second. + RetryBackoff func(n int, r *http.Request, resp *http.Response) time.Duration + + // UserAgent is prepended to the User-Agent header sent to the ACME server, + // which by default is this package's name and version. + // + // Reusable libraries and tools in particular should set this value to be + // identifiable by the server, in case they are causing issues. + UserAgent string + + cacheMu sync.Mutex + dir *Directory // cached result of Client's Discover method + // KID is the key identifier provided by the CA. If not provided it will be + // retrieved from the CA by making a call to the registration endpoint. + KID KeyID + + noncesMu sync.Mutex + nonces map[string]struct{} // nonces collected from previous responses +} + +// accountKID returns a key ID associated with c.Key, the account identity +// provided by the CA during RFC based registration. +// It assumes c.Discover has already been called. +// +// accountKID requires at most one network roundtrip. +// It caches only successful result. +// +// When in pre-RFC mode or when c.getRegRFC responds with an error, accountKID +// returns noKeyID. 
+func (c *Client) accountKID(ctx context.Context) KeyID { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + if c.KID != noKeyID { + return c.KID + } + a, err := c.getRegRFC(ctx) + if err != nil { + return noKeyID + } + c.KID = KeyID(a.URI) + return c.KID +} + +var errPreRFC = errors.New("acme: server does not support the RFC 8555 version of ACME") + +// Discover performs ACME server discovery using c.DirectoryURL. +// +// It caches successful result. So, subsequent calls will not result in +// a network round-trip. This also means mutating c.DirectoryURL after successful call +// of this method will have no effect. +func (c *Client) Discover(ctx context.Context) (Directory, error) { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + if c.dir != nil { + return *c.dir, nil + } + + res, err := c.get(ctx, c.directoryURL(), wantStatus(http.StatusOK)) + if err != nil { + return Directory{}, err + } + defer res.Body.Close() + c.addNonce(res.Header) + + var v struct { + Reg string `json:"newAccount"` + Authz string `json:"newAuthz"` + Order string `json:"newOrder"` + Revoke string `json:"revokeCert"` + Nonce string `json:"newNonce"` + KeyChange string `json:"keyChange"` + RenewalInfo string `json:"renewalInfo"` + Meta struct { + Terms string `json:"termsOfService"` + Website string `json:"website"` + CAA []string `json:"caaIdentities"` + ExternalAcct bool `json:"externalAccountRequired"` + } + } + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return Directory{}, err + } + if v.Order == "" { + return Directory{}, errPreRFC + } + c.dir = &Directory{ + RegURL: v.Reg, + AuthzURL: v.Authz, + OrderURL: v.Order, + RevokeURL: v.Revoke, + NonceURL: v.Nonce, + KeyChangeURL: v.KeyChange, + RenewalInfoURL: v.RenewalInfo, + Terms: v.Meta.Terms, + Website: v.Meta.Website, + CAA: v.Meta.CAA, + ExternalAccountRequired: v.Meta.ExternalAcct, + } + return *c.dir, nil +} + +func (c *Client) directoryURL() string { + if c.DirectoryURL != "" { + return c.DirectoryURL + } + return 
LetsEncryptURL +} + +// CreateCert was part of the old version of ACME. It is incompatible with RFC 8555. +// +// Deprecated: this was for the pre-RFC 8555 version of ACME. Callers should use CreateOrderCert. +func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, bundle bool) (der [][]byte, certURL string, err error) { + return nil, "", errPreRFC +} + +// FetchCert retrieves already issued certificate from the given url, in DER format. +// It retries the request until the certificate is successfully retrieved, +// context is cancelled by the caller or an error response is received. +// +// If the bundle argument is true, the returned value also contains the CA (issuer) +// certificate chain. +// +// FetchCert returns an error if the CA's response or chain was unreasonably large. +// Callers are encouraged to parse the returned value to ensure the certificate is valid +// and has expected features. +func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + return c.fetchCertRFC(ctx, url, bundle) +} + +// RevokeCert revokes a previously issued certificate cert, provided in DER format. +// +// The key argument, used to sign the request, must be authorized +// to revoke the certificate. It's up to the CA to decide which keys are authorized. +// For instance, the key pair of the certificate may be authorized. +// If the key is nil, c.Key is used instead. +func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error { + if _, err := c.Discover(ctx); err != nil { + return err + } + return c.revokeCertRFC(ctx, key, cert, reason) +} + +// FetchRenewalInfo retrieves the RenewalInfo from Directory.RenewalInfoURL. 
+func (c *Client) FetchRenewalInfo(ctx context.Context, leaf, issuer []byte) (*RenewalInfo, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + parsedLeaf, err := x509.ParseCertificate(leaf) + if err != nil { + return nil, fmt.Errorf("parsing leaf certificate: %w", err) + } + parsedIssuer, err := x509.ParseCertificate(issuer) + if err != nil { + return nil, fmt.Errorf("parsing issuer certificate: %w", err) + } + + renewalURL, err := c.getRenewalURL(parsedLeaf, parsedIssuer) + if err != nil { + return nil, fmt.Errorf("generating renewal info URL: %w", err) + } + + res, err := c.get(ctx, renewalURL, wantStatus(http.StatusOK)) + if err != nil { + return nil, fmt.Errorf("fetching renewal info: %w", err) + } + defer res.Body.Close() + + var info RenewalInfo + if err := json.NewDecoder(res.Body).Decode(&info); err != nil { + return nil, fmt.Errorf("parsing renewal info response: %w", err) + } + return &info, nil +} + +func (c *Client) getRenewalURL(cert, issuer *x509.Certificate) (string, error) { + // See https://www.ietf.org/archive/id/draft-ietf-acme-ari-01.html#name-getting-renewal-information + // for how the request URL is built. 
+ var publicKeyInfo struct { + Algorithm pkix.AlgorithmIdentifier + PublicKey asn1.BitString + } + if _, err := asn1.Unmarshal(issuer.RawSubjectPublicKeyInfo, &publicKeyInfo); err != nil { + return "", fmt.Errorf("parsing RawSubjectPublicKeyInfo of the issuer certificate: %w", err) + } + + h := crypto.SHA256.New() + h.Write(publicKeyInfo.PublicKey.RightAlign()) + issuerKeyHash := h.Sum(nil) + + h.Reset() + h.Write(issuer.RawSubject) + issuerNameHash := h.Sum(nil) + + // CertID ASN1 structure defined in + // https://datatracker.ietf.org/doc/html/rfc6960#section-4.1.1 + certID, err := asn1.Marshal(struct { + HashAlgorithm pkix.AlgorithmIdentifier + NameHash []byte + IssuerKeyHash []byte + SerialNumber *big.Int + }{ + pkix.AlgorithmIdentifier{ + // SHA256 OID + Algorithm: asn1.ObjectIdentifier([]int{2, 16, 840, 1, 101, 3, 4, 2, 1}), + Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, + }, + issuerNameHash, + issuerKeyHash, + cert.SerialNumber, + }) + if err != nil { + return "", fmt.Errorf("marshaling CertID: %w", err) + } + + url := c.dir.RenewalInfoURL + if !strings.HasSuffix(url, "/") { + url += "/" + } + return url + base64.RawURLEncoding.EncodeToString(certID), nil +} + +// AcceptTOS always returns true to indicate the acceptance of a CA's Terms of Service +// during account registration. See Register method of Client for more details. +func AcceptTOS(tosURL string) bool { return true } + +// Register creates a new account with the CA using c.Key. +// It returns the registered account. The account acct is not modified. +// +// The registration may require the caller to agree to the CA's Terms of Service (TOS). +// If so, and the account has not indicated the acceptance of the terms (see Account for details), +// Register calls prompt with a TOS URL provided by the CA. Prompt should report +// whether the caller agrees to the terms. To always accept the terms, the caller can use AcceptTOS. 
+// +// When interfacing with an RFC-compliant CA, non-RFC 8555 fields of acct are ignored +// and prompt is called if Directory's Terms field is non-zero. +// Also see Error's Instance field for when a CA requires already registered accounts to agree +// to an updated Terms of Service. +func (c *Client) Register(ctx context.Context, acct *Account, prompt func(tosURL string) bool) (*Account, error) { + if c.Key == nil { + return nil, errors.New("acme: client.Key must be set to Register") + } + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + return c.registerRFC(ctx, acct, prompt) +} + +// GetReg retrieves an existing account associated with c.Key. +// +// The url argument is a legacy artifact of the pre-RFC 8555 API +// and is ignored. +func (c *Client) GetReg(ctx context.Context, url string) (*Account, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + return c.getRegRFC(ctx) +} + +// UpdateReg updates an existing registration. +// It returns an updated account copy. The provided account is not modified. +// +// The account's URI is ignored and the account URL associated with +// c.Key is used instead. +func (c *Client) UpdateReg(ctx context.Context, acct *Account) (*Account, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + return c.updateRegRFC(ctx, acct) +} + +// AccountKeyRollover attempts to transition a client's account key to a new key. +// On success client's Key is updated which is not concurrency safe. +// On failure an error will be returned. +// The new key is already registered with the ACME provider if the following is true: +// - error is of type acme.Error +// - StatusCode should be 409 (Conflict) +// - Location header will have the KID of the associated account +// +// More about account key rollover can be found at +// https://tools.ietf.org/html/rfc8555#section-7.3.5. 
+func (c *Client) AccountKeyRollover(ctx context.Context, newKey crypto.Signer) error { + return c.accountKeyRollover(ctx, newKey) +} + +// Authorize performs the initial step in the pre-authorization flow, +// as opposed to order-based flow. +// The caller will then need to choose from and perform a set of returned +// challenges using c.Accept in order to successfully complete authorization. +// +// Once complete, the caller can use AuthorizeOrder which the CA +// should provision with the already satisfied authorization. +// For pre-RFC CAs, the caller can proceed directly to requesting a certificate +// using CreateCert method. +// +// If an authorization has been previously granted, the CA may return +// a valid authorization which has its Status field set to StatusValid. +// +// More about pre-authorization can be found at +// https://tools.ietf.org/html/rfc8555#section-7.4.1. +func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, error) { + return c.authorize(ctx, "dns", domain) +} + +// AuthorizeIP is the same as Authorize but requests IP address authorization. +// Clients which successfully obtain such authorization may request to issue +// a certificate for IP addresses. +// +// See the ACME spec extension for more details about IP address identifiers: +// https://tools.ietf.org/html/draft-ietf-acme-ip. 
+func (c *Client) AuthorizeIP(ctx context.Context, ipaddr string) (*Authorization, error) { + return c.authorize(ctx, "ip", ipaddr) +} + +func (c *Client) authorize(ctx context.Context, typ, val string) (*Authorization, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + type authzID struct { + Type string `json:"type"` + Value string `json:"value"` + } + req := struct { + Resource string `json:"resource"` + Identifier authzID `json:"identifier"` + }{ + Resource: "new-authz", + Identifier: authzID{Type: typ, Value: val}, + } + res, err := c.post(ctx, nil, c.dir.AuthzURL, req, wantStatus(http.StatusCreated)) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var v wireAuthz + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid response: %v", err) + } + if v.Status != StatusPending && v.Status != StatusValid { + return nil, fmt.Errorf("acme: unexpected status: %s", v.Status) + } + return v.authorization(res.Header.Get("Location")), nil +} + +// GetAuthorization retrieves an authorization identified by the given URL. +// +// If a caller needs to poll an authorization until its status is final, +// see the WaitAuthorization method. +func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + var v wireAuthz + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid response: %v", err) + } + return v.authorization(url), nil +} + +// RevokeAuthorization relinquishes an existing authorization identified +// by the given URL. +// The url argument is an Authorization.URI value. 
+// +// If successful, the caller will be required to obtain a new authorization +// using the Authorize or AuthorizeOrder methods before being able to request +// a new certificate for the domain associated with the authorization. +// +// It does not revoke existing certificates. +func (c *Client) RevokeAuthorization(ctx context.Context, url string) error { + if _, err := c.Discover(ctx); err != nil { + return err + } + + req := struct { + Resource string `json:"resource"` + Status string `json:"status"` + Delete bool `json:"delete"` + }{ + Resource: "authz", + Status: "deactivated", + Delete: true, + } + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return err + } + defer res.Body.Close() + return nil +} + +// WaitAuthorization polls an authorization at the given URL +// until it is in one of the final states, StatusValid or StatusInvalid, +// the ACME CA responded with a 4xx error code, or the context is done. +// +// It returns a non-nil Authorization only if its Status is StatusValid. +// In all other cases WaitAuthorization returns an error. +// If the Status is StatusInvalid, the returned error is of type *AuthorizationError. +func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + for { + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted)) + if err != nil { + return nil, err + } + + var raw wireAuthz + err = json.NewDecoder(res.Body).Decode(&raw) + res.Body.Close() + switch { + case err != nil: + // Skip and retry. + case raw.Status == StatusValid: + return raw.authorization(url), nil + case raw.Status == StatusInvalid: + return nil, raw.error(url) + } + + // Exponential backoff is implemented in c.get above. + // This is just to prevent continuously hitting the CA + // while waiting for a final authorization status. 
+ d := retryAfter(res.Header.Get("Retry-After")) + if d == 0 { + // Given that the fastest challenges TLS-SNI and HTTP-01 + // require a CA to make at least 1 network round trip + // and most likely persist a challenge state, + // this default delay seems reasonable. + d = time.Second + } + t := time.NewTimer(d) + select { + case <-ctx.Done(): + t.Stop() + return nil, ctx.Err() + case <-t.C: + // Retry. + } + } +} + +// GetChallenge retrieves the current status of an challenge. +// +// A client typically polls a challenge status using this method. +func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted)) + if err != nil { + return nil, err + } + + defer res.Body.Close() + v := wireChallenge{URI: url} + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid response: %v", err) + } + return v.challenge(), nil +} + +// Accept informs the server that the client accepts one of its challenges +// previously obtained with c.Authorize. +// +// The server will then perform the validation asynchronously. +func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + res, err := c.post(ctx, nil, chal.URI, json.RawMessage("{}"), wantStatus( + http.StatusOK, // according to the spec + http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md) + )) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var v wireChallenge + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid response: %v", err) + } + return v.challenge(), nil +} + +// DNS01ChallengeRecord returns a DNS record value for a dns-01 challenge response. 
+// A TXT record containing the returned value must be provisioned under +// "_acme-challenge" name of the domain being validated. +// +// The token argument is a Challenge.Token value. +func (c *Client) DNS01ChallengeRecord(token string) (string, error) { + ka, err := keyAuth(c.Key.Public(), token) + if err != nil { + return "", err + } + b := sha256.Sum256([]byte(ka)) + return base64.RawURLEncoding.EncodeToString(b[:]), nil +} + +// HTTP01ChallengeResponse returns the response for an http-01 challenge. +// Servers should respond with the value to HTTP requests at the URL path +// provided by HTTP01ChallengePath to validate the challenge and prove control +// over a domain name. +// +// The token argument is a Challenge.Token value. +func (c *Client) HTTP01ChallengeResponse(token string) (string, error) { + return keyAuth(c.Key.Public(), token) +} + +// HTTP01ChallengePath returns the URL path at which the response for an http-01 challenge +// should be provided by the servers. +// The response value can be obtained with HTTP01ChallengeResponse. +// +// The token argument is a Challenge.Token value. +func (c *Client) HTTP01ChallengePath(token string) string { + return "/.well-known/acme-challenge/" + token +} + +// TLSSNI01ChallengeCert creates a certificate for TLS-SNI-01 challenge response. +// +// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec. +func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { + ka, err := keyAuth(c.Key.Public(), token) + if err != nil { + return tls.Certificate{}, "", err + } + b := sha256.Sum256([]byte(ka)) + h := hex.EncodeToString(b[:]) + name = fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:]) + cert, err = tlsChallengeCert([]string{name}, opt) + if err != nil { + return tls.Certificate{}, "", err + } + return cert, name, nil +} + +// TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response. 
+// +// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec. +func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { + b := sha256.Sum256([]byte(token)) + h := hex.EncodeToString(b[:]) + sanA := fmt.Sprintf("%s.%s.token.acme.invalid", h[:32], h[32:]) + + ka, err := keyAuth(c.Key.Public(), token) + if err != nil { + return tls.Certificate{}, "", err + } + b = sha256.Sum256([]byte(ka)) + h = hex.EncodeToString(b[:]) + sanB := fmt.Sprintf("%s.%s.ka.acme.invalid", h[:32], h[32:]) + + cert, err = tlsChallengeCert([]string{sanA, sanB}, opt) + if err != nil { + return tls.Certificate{}, "", err + } + return cert, sanA, nil +} + +// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response. +// Servers can present the certificate to validate the challenge and prove control +// over a domain name. For more details on TLS-ALPN-01 see +// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3 +// +// The token argument is a Challenge.Token value. +// If a WithKey option is provided, its private part signs the returned cert, +// and the public part is used to specify the signee. +// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve. +// +// The returned certificate is valid for the next 24 hours and must be presented only when +// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol +// has been specified. 
+func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) { + ka, err := keyAuth(c.Key.Public(), token) + if err != nil { + return tls.Certificate{}, err + } + shasum := sha256.Sum256([]byte(ka)) + extValue, err := asn1.Marshal(shasum[:]) + if err != nil { + return tls.Certificate{}, err + } + acmeExtension := pkix.Extension{ + Id: idPeACMEIdentifier, + Critical: true, + Value: extValue, + } + + tmpl := defaultTLSChallengeCertTemplate() + + var newOpt []CertOption + for _, o := range opt { + switch o := o.(type) { + case *certOptTemplate: + t := *(*x509.Certificate)(o) // shallow copy is ok + tmpl = &t + default: + newOpt = append(newOpt, o) + } + } + tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension) + newOpt = append(newOpt, WithTemplate(tmpl)) + return tlsChallengeCert([]string{domain}, newOpt) +} + +// popNonce returns a nonce value previously stored with c.addNonce +// or fetches a fresh one from c.dir.NonceURL. +// If NonceURL is empty, it first tries c.directoryURL() and, failing that, +// the provided url. +func (c *Client) popNonce(ctx context.Context, url string) (string, error) { + c.noncesMu.Lock() + defer c.noncesMu.Unlock() + if len(c.nonces) == 0 { + if c.dir != nil && c.dir.NonceURL != "" { + return c.fetchNonce(ctx, c.dir.NonceURL) + } + dirURL := c.directoryURL() + v, err := c.fetchNonce(ctx, dirURL) + if err != nil && url != dirURL { + v, err = c.fetchNonce(ctx, url) + } + return v, err + } + var nonce string + for nonce = range c.nonces { + delete(c.nonces, nonce) + break + } + return nonce, nil +} + +// clearNonces clears any stored nonces +func (c *Client) clearNonces() { + c.noncesMu.Lock() + defer c.noncesMu.Unlock() + c.nonces = make(map[string]struct{}) +} + +// addNonce stores a nonce value found in h (if any) for future use. 
+func (c *Client) addNonce(h http.Header) { + v := nonceFromHeader(h) + if v == "" { + return + } + c.noncesMu.Lock() + defer c.noncesMu.Unlock() + if len(c.nonces) >= maxNonces { + return + } + if c.nonces == nil { + c.nonces = make(map[string]struct{}) + } + c.nonces[v] = struct{}{} +} + +func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) { + r, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return "", err + } + resp, err := c.doNoRetry(ctx, r) + if err != nil { + return "", err + } + defer resp.Body.Close() + nonce := nonceFromHeader(resp.Header) + if nonce == "" { + if resp.StatusCode > 299 { + return "", responseError(resp) + } + return "", errors.New("acme: nonce not found") + } + return nonce, nil +} + +func nonceFromHeader(h http.Header) string { + return h.Get("Replay-Nonce") +} + +// linkHeader returns URI-Reference values of all Link headers +// with relation-type rel. +// See https://tools.ietf.org/html/rfc5988#section-5 for details. +func linkHeader(h http.Header, rel string) []string { + var links []string + for _, v := range h["Link"] { + parts := strings.Split(v, ";") + for _, p := range parts { + p = strings.TrimSpace(p) + if !strings.HasPrefix(p, "rel=") { + continue + } + if v := strings.Trim(p[4:], `"`); v == rel { + links = append(links, strings.Trim(parts[0], "<>")) + } + } + } + return links +} + +// keyAuth generates a key authorization string for a given token. +func keyAuth(pub crypto.PublicKey, token string) (string, error) { + th, err := JWKThumbprint(pub) + if err != nil { + return "", err + } + return fmt.Sprintf("%s.%s", token, th), nil +} + +// defaultTLSChallengeCertTemplate is a template used to create challenge certs for TLS challenges. 
+func defaultTLSChallengeCertTemplate() *x509.Certificate { + return &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } +} + +// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges +// with the given SANs and auto-generated public/private key pair. +// The Subject Common Name is set to the first SAN to aid debugging. +// To create a cert with a custom key pair, specify WithKey option. +func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) { + var key crypto.Signer + tmpl := defaultTLSChallengeCertTemplate() + for _, o := range opt { + switch o := o.(type) { + case *certOptKey: + if key != nil { + return tls.Certificate{}, errors.New("acme: duplicate key option") + } + key = o.key + case *certOptTemplate: + t := *(*x509.Certificate)(o) // shallow copy is ok + tmpl = &t + default: + // package's fault, if we let this happen: + panic(fmt.Sprintf("unsupported option type %T", o)) + } + } + if key == nil { + var err error + if key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader); err != nil { + return tls.Certificate{}, err + } + } + tmpl.DNSNames = san + if len(san) > 0 { + tmpl.Subject.CommonName = san[0] + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) + if err != nil { + return tls.Certificate{}, err + } + return tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: key, + }, nil +} + +// encodePEM returns b encoded as PEM with block of type typ. +func encodePEM(typ string, b []byte) []byte { + pb := &pem.Block{Type: typ, Bytes: b} + return pem.EncodeToMemory(pb) +} + +// timeNow is time.Now, except in tests which can mess with it. 
+var timeNow = time.Now diff --git a/vendor/github.com/tailscale/golang-x-crypto/acme/http.go b/vendor/github.com/tailscale/golang-x-crypto/acme/http.go new file mode 100644 index 0000000000..58836e5d30 --- /dev/null +++ b/vendor/github.com/tailscale/golang-x-crypto/acme/http.go @@ -0,0 +1,325 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "bytes" + "context" + "crypto" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "strconv" + "strings" + "time" +) + +// retryTimer encapsulates common logic for retrying unsuccessful requests. +// It is not safe for concurrent use. +type retryTimer struct { + // backoffFn provides backoff delay sequence for retries. + // See Client.RetryBackoff doc comment. + backoffFn func(n int, r *http.Request, res *http.Response) time.Duration + // n is the current retry attempt. + n int +} + +func (t *retryTimer) inc() { + t.n++ +} + +// backoff pauses the current goroutine as described in Client.RetryBackoff. +func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Response) error { + d := t.backoffFn(t.n, r, res) + if d <= 0 { + return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", r.URL, t.n) + } + wakeup := time.NewTimer(d) + defer wakeup.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-wakeup.C: + return nil + } +} + +func (c *Client) retryTimer() *retryTimer { + f := c.RetryBackoff + if f == nil { + f = defaultBackoff + } + return &retryTimer{backoffFn: f} +} + +// defaultBackoff provides default Client.RetryBackoff implementation +// using a truncated exponential backoff algorithm, +// as described in Client.RetryBackoff. +// +// The n argument is always bounded between 1 and 30. +// The returned value is always greater than 0. 
+func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration { + const max = 10 * time.Second + var jitter time.Duration + if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil { + // Set the minimum to 1ms to avoid a case where + // an invalid Retry-After value is parsed into 0 below, + // resulting in the 0 returned value which would unintentionally + // stop the retries. + jitter = (1 + time.Duration(x.Int64())) * time.Millisecond + } + if v, ok := res.Header["Retry-After"]; ok { + return retryAfter(v[0]) + jitter + } + + if n < 1 { + n = 1 + } + if n > 30 { + n = 30 + } + d := time.Duration(1< max { + return max + } + return d +} + +// retryAfter parses a Retry-After HTTP header value, +// trying to convert v into an int (seconds) or use http.ParseTime otherwise. +// It returns zero value if v cannot be parsed. +func retryAfter(v string) time.Duration { + if i, err := strconv.Atoi(v); err == nil { + return time.Duration(i) * time.Second + } + t, err := http.ParseTime(v) + if err != nil { + return 0 + } + return t.Sub(timeNow()) +} + +// resOkay is a function that reports whether the provided response is okay. +// It is expected to keep the response body unread. +type resOkay func(*http.Response) bool + +// wantStatus returns a function which reports whether the code +// matches the status code of a response. +func wantStatus(codes ...int) resOkay { + return func(res *http.Response) bool { + for _, code := range codes { + if code == res.StatusCode { + return true + } + } + return false + } +} + +// get issues an unsigned GET request to the specified URL. +// It returns a non-error value only when ok reports true. +// +// get retries unsuccessful attempts according to c.RetryBackoff +// until the context is done or a non-retriable error is received. 
+func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) { + retry := c.retryTimer() + for { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + res, err := c.doNoRetry(ctx, req) + switch { + case err != nil: + return nil, err + case ok(res): + return res, nil + case isRetriable(res.StatusCode): + retry.inc() + resErr := responseError(res) + res.Body.Close() + // Ignore the error value from retry.backoff + // and return the one from last retry, as received from the CA. + if retry.backoff(ctx, req, res) != nil { + return nil, resErr + } + default: + defer res.Body.Close() + return nil, responseError(res) + } + } +} + +// postAsGet is POST-as-GET, a replacement for GET in RFC 8555 +// as described in https://tools.ietf.org/html/rfc8555#section-6.3. +// It makes a POST request in KID form with zero JWS payload. +// See nopayload doc comments in jws.go. +func (c *Client) postAsGet(ctx context.Context, url string, ok resOkay) (*http.Response, error) { + return c.post(ctx, nil, url, noPayload, ok) +} + +// post issues a signed POST request in JWS format using the provided key +// to the specified URL. If key is nil, c.Key is used instead. +// It returns a non-error value only when ok reports true. +// +// post retries unsuccessful attempts according to c.RetryBackoff +// until the context is done or a non-retriable error is received. +// It uses postNoRetry to make individual requests. +func (c *Client) post(ctx context.Context, key crypto.Signer, url string, body interface{}, ok resOkay) (*http.Response, error) { + retry := c.retryTimer() + for { + res, req, err := c.postNoRetry(ctx, key, url, body) + if err != nil { + return nil, err + } + if ok(res) { + return res, nil + } + resErr := responseError(res) + res.Body.Close() + switch { + // Check for bad nonce before isRetriable because it may have been returned + // with an unretriable response code such as 400 Bad Request. 
+ case isBadNonce(resErr): + // Consider any previously stored nonce values to be invalid. + c.clearNonces() + case !isRetriable(res.StatusCode): + return nil, resErr + } + retry.inc() + // Ignore the error value from retry.backoff + // and return the one from last retry, as received from the CA. + if err := retry.backoff(ctx, req, res); err != nil { + return nil, resErr + } + } +} + +// postNoRetry signs the body with the given key and POSTs it to the provided url. +// It is used by c.post to retry unsuccessful attempts. +// The body argument must be JSON-serializable. +// +// If key argument is nil, c.Key is used to sign the request. +// If key argument is nil and c.accountKID returns a non-zero keyID, +// the request is sent in KID form. Otherwise, JWK form is used. +// +// In practice, when interfacing with RFC-compliant CAs most requests are sent in KID form +// and JWK is used only when KID is unavailable: new account endpoint and certificate +// revocation requests authenticated by a cert key. +// See jwsEncodeJSON for other details. +func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, *http.Request, error) { + kid := noKeyID + if key == nil { + if c.Key == nil { + return nil, nil, errors.New("acme: Client.Key must be populated to make POST requests") + } + key = c.Key + kid = c.accountKID(ctx) + } + nonce, err := c.popNonce(ctx, url) + if err != nil { + return nil, nil, err + } + b, err := jwsEncodeJSON(body, key, kid, nonce, url) + if err != nil { + return nil, nil, err + } + req, err := http.NewRequest("POST", url, bytes.NewReader(b)) + if err != nil { + return nil, nil, err + } + req.Header.Set("Content-Type", "application/jose+json") + res, err := c.doNoRetry(ctx, req) + if err != nil { + return nil, nil, err + } + c.addNonce(res.Header) + return res, req, nil +} + +// doNoRetry issues a request req, replacing its context (if any) with ctx. 
+func (c *Client) doNoRetry(ctx context.Context, req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", c.userAgent()) + res, err := c.httpClient().Do(req.WithContext(ctx)) + if err != nil { + select { + case <-ctx.Done(): + // Prefer the unadorned context error. + // (The acme package had tests assuming this, previously from ctxhttp's + // behavior, predating net/http supporting contexts natively) + // TODO(bradfitz): reconsider this in the future. But for now this + // requires no test updates. + return nil, ctx.Err() + default: + return nil, err + } + } + return res, nil +} + +func (c *Client) httpClient() *http.Client { + if c.HTTPClient != nil { + return c.HTTPClient + } + return http.DefaultClient +} + +// packageVersion is the version of the module that contains this package, for +// sending as part of the User-Agent header. It's set in version_go112.go. +var packageVersion string + +// userAgent returns the User-Agent header value. It includes the package name, +// the module version (if available), and the c.UserAgent value (if set). +func (c *Client) userAgent() string { + ua := "golang.org/x/crypto/acme" + if packageVersion != "" { + ua += "@" + packageVersion + } + if c.UserAgent != "" { + ua = c.UserAgent + " " + ua + } + return ua +} + +// isBadNonce reports whether err is an ACME "badnonce" error. +func isBadNonce(err error) bool { + // According to the spec badNonce is urn:ietf:params:acme:error:badNonce. + // However, ACME servers in the wild return their versions of the error. + // See https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4 + // and https://github.com/letsencrypt/boulder/blob/0e07eacb/docs/acme-divergences.md#section-66. + ae, ok := err.(*Error) + return ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") +} + +// isRetriable reports whether a request can be retried +// based on the response status code. 
+// +// Note that a "bad nonce" error is returned with a non-retriable 400 Bad Request code. +// Callers should parse the response and check with isBadNonce. +func isRetriable(code int) bool { + return code <= 399 || code >= 500 || code == http.StatusTooManyRequests +} + +// responseError creates an error of Error type from resp. +func responseError(resp *http.Response) error { + // don't care if ReadAll returns an error: + // json.Unmarshal will fail in that case anyway + b, _ := io.ReadAll(resp.Body) + e := &wireError{Status: resp.StatusCode} + if err := json.Unmarshal(b, e); err != nil { + // this is not a regular error response: + // populate detail with anything we received, + // e.Status will already contain HTTP response code value + e.Detail = string(b) + if e.Detail == "" { + e.Detail = resp.Status + } + } + return e.error(resp.Header) +} diff --git a/vendor/github.com/tailscale/golang-x-crypto/acme/jws.go b/vendor/github.com/tailscale/golang-x-crypto/acme/jws.go new file mode 100644 index 0000000000..b38828d859 --- /dev/null +++ b/vendor/github.com/tailscale/golang-x-crypto/acme/jws.go @@ -0,0 +1,257 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "crypto" + "crypto/ecdsa" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + _ "crypto/sha512" // need for EC keys + "encoding/asn1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math/big" +) + +// KeyID is the account key identity provided by a CA during registration. +type KeyID string + +// noKeyID indicates that jwsEncodeJSON should compute and use JWK instead of a KID. +// See jwsEncodeJSON for details. +const noKeyID = KeyID("") + +// noPayload indicates jwsEncodeJSON will encode zero-length octet string +// in a JWS request. 
This is called POST-as-GET in RFC 8555 and is used to make +// authenticated GET requests via POSTing with an empty payload. +// See https://tools.ietf.org/html/rfc8555#section-6.3 for more details. +const noPayload = "" + +// noNonce indicates that the nonce should be omitted from the protected header. +// See jwsEncodeJSON for details. +const noNonce = "" + +// jsonWebSignature can be easily serialized into a JWS following +// https://tools.ietf.org/html/rfc7515#section-3.2. +type jsonWebSignature struct { + Protected string `json:"protected"` + Payload string `json:"payload"` + Sig string `json:"signature"` +} + +// jwsEncodeJSON signs claimset using provided key and a nonce. +// The result is serialized in JSON format containing either kid or jwk +// fields based on the provided KeyID value. +// +// The claimset is marshalled using json.Marshal unless it is a string. +// In which case it is inserted directly into the message. +// +// If kid is non-empty, its quoted value is inserted in the protected header +// as "kid" field value. Otherwise, JWK is computed using jwkEncode and inserted +// as "jwk" field value. The "jwk" and "kid" fields are mutually exclusive. +// +// If nonce is non-empty, its quoted value is inserted in the protected header. +// +// See https://tools.ietf.org/html/rfc7515#section-7. 
+func jwsEncodeJSON(claimset interface{}, key crypto.Signer, kid KeyID, nonce, url string) ([]byte, error) { + if key == nil { + return nil, errors.New("nil key") + } + alg, sha := jwsHasher(key.Public()) + if alg == "" || !sha.Available() { + return nil, ErrUnsupportedKey + } + headers := struct { + Alg string `json:"alg"` + KID string `json:"kid,omitempty"` + JWK json.RawMessage `json:"jwk,omitempty"` + Nonce string `json:"nonce,omitempty"` + URL string `json:"url"` + }{ + Alg: alg, + Nonce: nonce, + URL: url, + } + switch kid { + case noKeyID: + jwk, err := jwkEncode(key.Public()) + if err != nil { + return nil, err + } + headers.JWK = json.RawMessage(jwk) + default: + headers.KID = string(kid) + } + phJSON, err := json.Marshal(headers) + if err != nil { + return nil, err + } + phead := base64.RawURLEncoding.EncodeToString([]byte(phJSON)) + var payload string + if val, ok := claimset.(string); ok { + payload = val + } else { + cs, err := json.Marshal(claimset) + if err != nil { + return nil, err + } + payload = base64.RawURLEncoding.EncodeToString(cs) + } + hash := sha.New() + hash.Write([]byte(phead + "." + payload)) + sig, err := jwsSign(key, sha, hash.Sum(nil)) + if err != nil { + return nil, err + } + enc := jsonWebSignature{ + Protected: phead, + Payload: payload, + Sig: base64.RawURLEncoding.EncodeToString(sig), + } + return json.Marshal(&enc) +} + +// jwsWithMAC creates and signs a JWS using the given key and the HS256 +// algorithm. kid and url are included in the protected header. rawPayload +// should not be base64-URL-encoded. +func jwsWithMAC(key []byte, kid, url string, rawPayload []byte) (*jsonWebSignature, error) { + if len(key) == 0 { + return nil, errors.New("acme: cannot sign JWS with an empty MAC key") + } + header := struct { + Algorithm string `json:"alg"` + KID string `json:"kid"` + URL string `json:"url,omitempty"` + }{ + // Only HMAC-SHA256 is supported. 
+ Algorithm: "HS256", + KID: kid, + URL: url, + } + rawProtected, err := json.Marshal(header) + if err != nil { + return nil, err + } + protected := base64.RawURLEncoding.EncodeToString(rawProtected) + payload := base64.RawURLEncoding.EncodeToString(rawPayload) + + h := hmac.New(sha256.New, key) + if _, err := h.Write([]byte(protected + "." + payload)); err != nil { + return nil, err + } + mac := h.Sum(nil) + + return &jsonWebSignature{ + Protected: protected, + Payload: payload, + Sig: base64.RawURLEncoding.EncodeToString(mac), + }, nil +} + +// jwkEncode encodes public part of an RSA or ECDSA key into a JWK. +// The result is also suitable for creating a JWK thumbprint. +// https://tools.ietf.org/html/rfc7517 +func jwkEncode(pub crypto.PublicKey) (string, error) { + switch pub := pub.(type) { + case *rsa.PublicKey: + // https://tools.ietf.org/html/rfc7518#section-6.3.1 + n := pub.N + e := big.NewInt(int64(pub.E)) + // Field order is important. + // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. + return fmt.Sprintf(`{"e":"%s","kty":"RSA","n":"%s"}`, + base64.RawURLEncoding.EncodeToString(e.Bytes()), + base64.RawURLEncoding.EncodeToString(n.Bytes()), + ), nil + case *ecdsa.PublicKey: + // https://tools.ietf.org/html/rfc7518#section-6.2.1 + p := pub.Curve.Params() + n := p.BitSize / 8 + if p.BitSize%8 != 0 { + n++ + } + x := pub.X.Bytes() + if n > len(x) { + x = append(make([]byte, n-len(x)), x...) + } + y := pub.Y.Bytes() + if n > len(y) { + y = append(make([]byte, n-len(y)), y...) + } + // Field order is important. + // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. + return fmt.Sprintf(`{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`, + p.Name, + base64.RawURLEncoding.EncodeToString(x), + base64.RawURLEncoding.EncodeToString(y), + ), nil + } + return "", ErrUnsupportedKey +} + +// jwsSign signs the digest using the given key. +// The hash is unused for ECDSA keys. 
+func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) { + switch pub := key.Public().(type) { + case *rsa.PublicKey: + return key.Sign(rand.Reader, digest, hash) + case *ecdsa.PublicKey: + sigASN1, err := key.Sign(rand.Reader, digest, hash) + if err != nil { + return nil, err + } + + var rs struct{ R, S *big.Int } + if _, err := asn1.Unmarshal(sigASN1, &rs); err != nil { + return nil, err + } + + rb, sb := rs.R.Bytes(), rs.S.Bytes() + size := pub.Params().BitSize / 8 + if size%8 > 0 { + size++ + } + sig := make([]byte, size*2) + copy(sig[size-len(rb):], rb) + copy(sig[size*2-len(sb):], sb) + return sig, nil + } + return nil, ErrUnsupportedKey +} + +// jwsHasher indicates suitable JWS algorithm name and a hash function +// to use for signing a digest with the provided key. +// It returns ("", 0) if the key is not supported. +func jwsHasher(pub crypto.PublicKey) (string, crypto.Hash) { + switch pub := pub.(type) { + case *rsa.PublicKey: + return "RS256", crypto.SHA256 + case *ecdsa.PublicKey: + switch pub.Params().Name { + case "P-256": + return "ES256", crypto.SHA256 + case "P-384": + return "ES384", crypto.SHA384 + case "P-521": + return "ES512", crypto.SHA512 + } + } + return "", 0 +} + +// JWKThumbprint creates a JWK thumbprint out of pub +// as specified in https://tools.ietf.org/html/rfc7638. +func JWKThumbprint(pub crypto.PublicKey) (string, error) { + jwk, err := jwkEncode(pub) + if err != nil { + return "", err + } + b := sha256.Sum256([]byte(jwk)) + return base64.RawURLEncoding.EncodeToString(b[:]), nil +} diff --git a/vendor/github.com/tailscale/golang-x-crypto/acme/rfc8555.go b/vendor/github.com/tailscale/golang-x-crypto/acme/rfc8555.go new file mode 100644 index 0000000000..3152e531b6 --- /dev/null +++ b/vendor/github.com/tailscale/golang-x-crypto/acme/rfc8555.go @@ -0,0 +1,476 @@ +// Copyright 2019 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "context" + "crypto" + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "net/http" + "time" +) + +// DeactivateReg permanently disables an existing account associated with c.Key. +// A deactivated account can no longer request certificate issuance or access +// resources related to the account, such as orders or authorizations. +// +// It only works with CAs implementing RFC 8555. +func (c *Client) DeactivateReg(ctx context.Context) error { + if _, err := c.Discover(ctx); err != nil { // required by c.accountKID + return err + } + url := string(c.accountKID(ctx)) + if url == "" { + return ErrNoAccount + } + req := json.RawMessage(`{"status": "deactivated"}`) + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return err + } + res.Body.Close() + return nil +} + +// registerRFC is equivalent to c.Register but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. 
+func (c *Client) registerRFC(ctx context.Context, acct *Account, prompt func(tosURL string) bool) (*Account, error) { + c.cacheMu.Lock() // guard c.kid access + defer c.cacheMu.Unlock() + + req := struct { + TermsAgreed bool `json:"termsOfServiceAgreed,omitempty"` + Contact []string `json:"contact,omitempty"` + ExternalAccountBinding *jsonWebSignature `json:"externalAccountBinding,omitempty"` + }{ + Contact: acct.Contact, + } + if c.dir.Terms != "" { + req.TermsAgreed = prompt(c.dir.Terms) + } + + // set 'externalAccountBinding' field if requested + if acct.ExternalAccountBinding != nil { + eabJWS, err := c.encodeExternalAccountBinding(acct.ExternalAccountBinding) + if err != nil { + return nil, fmt.Errorf("acme: failed to encode external account binding: %v", err) + } + req.ExternalAccountBinding = eabJWS + } + + res, err := c.post(ctx, c.Key, c.dir.RegURL, req, wantStatus( + http.StatusOK, // account with this key already registered + http.StatusCreated, // new account created + )) + if err != nil { + return nil, err + } + + defer res.Body.Close() + a, err := responseAccount(res) + if err != nil { + return nil, err + } + // Cache Account URL even if we return an error to the caller. + // It is by all means a valid and usable "kid" value for future requests. + c.KID = KeyID(a.URI) + if res.StatusCode == http.StatusOK { + return nil, ErrAccountAlreadyExists + } + return a, nil +} + +// encodeExternalAccountBinding will encode an external account binding stanza +// as described in https://tools.ietf.org/html/rfc8555#section-7.3.4. +func (c *Client) encodeExternalAccountBinding(eab *ExternalAccountBinding) (*jsonWebSignature, error) { + jwk, err := jwkEncode(c.Key.Public()) + if err != nil { + return nil, err + } + return jwsWithMAC(eab.Key, eab.KID, c.dir.RegURL, []byte(jwk)) +} + +// updateRegRFC is equivalent to c.UpdateReg but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. 
+func (c *Client) updateRegRFC(ctx context.Context, a *Account) (*Account, error) { + url := string(c.accountKID(ctx)) + if url == "" { + return nil, ErrNoAccount + } + req := struct { + Contact []string `json:"contact,omitempty"` + }{ + Contact: a.Contact, + } + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + return responseAccount(res) +} + +// getRegRFC is equivalent to c.GetReg but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. +func (c *Client) getRegRFC(ctx context.Context) (*Account, error) { + req := json.RawMessage(`{"onlyReturnExisting": true}`) + res, err := c.post(ctx, c.Key, c.dir.RegURL, req, wantStatus(http.StatusOK)) + if e, ok := err.(*Error); ok && e.ProblemType == "urn:ietf:params:acme:error:accountDoesNotExist" { + return nil, ErrNoAccount + } + if err != nil { + return nil, err + } + + defer res.Body.Close() + return responseAccount(res) +} + +func responseAccount(res *http.Response) (*Account, error) { + var v struct { + Status string + Contact []string + Orders string + } + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid account response: %v", err) + } + return &Account{ + URI: res.Header.Get("Location"), + Status: v.Status, + Contact: v.Contact, + OrdersURL: v.Orders, + }, nil +} + +// accountKeyRollover attempts to perform account key rollover. +// On success it will change client.Key to the new key. 
+func (c *Client) accountKeyRollover(ctx context.Context, newKey crypto.Signer) error { + dir, err := c.Discover(ctx) // Also required by c.accountKID + if err != nil { + return err + } + kid := c.accountKID(ctx) + if kid == noKeyID { + return ErrNoAccount + } + oldKey, err := jwkEncode(c.Key.Public()) + if err != nil { + return err + } + payload := struct { + Account string `json:"account"` + OldKey json.RawMessage `json:"oldKey"` + }{ + Account: string(kid), + OldKey: json.RawMessage(oldKey), + } + inner, err := jwsEncodeJSON(payload, newKey, noKeyID, noNonce, dir.KeyChangeURL) + if err != nil { + return err + } + + res, err := c.post(ctx, nil, dir.KeyChangeURL, base64.RawURLEncoding.EncodeToString(inner), wantStatus(http.StatusOK)) + if err != nil { + return err + } + defer res.Body.Close() + c.Key = newKey + return nil +} + +// AuthorizeOrder initiates the order-based application for certificate issuance, +// as opposed to pre-authorization in Authorize. +// It is only supported by CAs implementing RFC 8555. +// +// The caller then needs to fetch each authorization with GetAuthorization, +// identify those with StatusPending status and fulfill a challenge using Accept. +// Once all authorizations are satisfied, the caller will typically want to poll +// order status using WaitOrder until it's in StatusReady state. +// To finalize the order and obtain a certificate, the caller submits a CSR with CreateOrderCert. 
+func (c *Client) AuthorizeOrder(ctx context.Context, id []AuthzID, opt ...OrderOption) (*Order, error) { + dir, err := c.Discover(ctx) + if err != nil { + return nil, err + } + + req := struct { + Identifiers []wireAuthzID `json:"identifiers"` + NotBefore string `json:"notBefore,omitempty"` + NotAfter string `json:"notAfter,omitempty"` + }{} + for _, v := range id { + req.Identifiers = append(req.Identifiers, wireAuthzID{ + Type: v.Type, + Value: v.Value, + }) + } + for _, o := range opt { + switch o := o.(type) { + case orderNotBeforeOpt: + req.NotBefore = time.Time(o).Format(time.RFC3339) + case orderNotAfterOpt: + req.NotAfter = time.Time(o).Format(time.RFC3339) + default: + // Package's fault if we let this happen. + panic(fmt.Sprintf("unsupported order option type %T", o)) + } + } + + res, err := c.post(ctx, nil, dir.OrderURL, req, wantStatus(http.StatusCreated)) + if err != nil { + return nil, err + } + defer res.Body.Close() + return responseOrder(res) +} + +// GetOrder retrives an order identified by the given URL. +// For orders created with AuthorizeOrder, the url value is Order.URI. +// +// If a caller needs to poll an order until its status is final, +// see the WaitOrder method. +func (c *Client) GetOrder(ctx context.Context, url string) (*Order, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + return responseOrder(res) +} + +// WaitOrder polls an order from the given URL until it is in one of the final states, +// StatusReady, StatusValid or StatusInvalid, the CA responded with a non-retryable error +// or the context is done. +// +// It returns a non-nil Order only if its Status is StatusReady or StatusValid. +// In all other cases WaitOrder returns an error. +// If the Status is StatusInvalid, the returned error is of type *OrderError. 
+func (c *Client) WaitOrder(ctx context.Context, url string) (*Order, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + for { + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + o, err := responseOrder(res) + res.Body.Close() + switch { + case err != nil: + // Skip and retry. + case o.Status == StatusInvalid: + return nil, &OrderError{OrderURL: o.URI, Status: o.Status} + case o.Status == StatusReady || o.Status == StatusValid: + return o, nil + } + + d := retryAfter(res.Header.Get("Retry-After")) + if d == 0 { + // Default retry-after. + // Same reasoning as in WaitAuthorization. + d = time.Second + } + t := time.NewTimer(d) + select { + case <-ctx.Done(): + t.Stop() + return nil, ctx.Err() + case <-t.C: + // Retry. + } + } +} + +func responseOrder(res *http.Response) (*Order, error) { + var v struct { + Status string + Expires time.Time + Identifiers []wireAuthzID + NotBefore time.Time + NotAfter time.Time + Error *wireError + Authorizations []string + Finalize string + Certificate string + } + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: error reading order: %v", err) + } + o := &Order{ + URI: res.Header.Get("Location"), + Status: v.Status, + Expires: v.Expires, + NotBefore: v.NotBefore, + NotAfter: v.NotAfter, + AuthzURLs: v.Authorizations, + FinalizeURL: v.Finalize, + CertURL: v.Certificate, + } + for _, id := range v.Identifiers { + o.Identifiers = append(o.Identifiers, AuthzID{Type: id.Type, Value: id.Value}) + } + if v.Error != nil { + o.Error = v.Error.error(nil /* headers */) + } + return o, nil +} + +// CreateOrderCert submits the CSR (Certificate Signing Request) to a CA at the specified URL. +// The URL is the FinalizeURL field of an Order created with AuthorizeOrder. +// +// If the bundle argument is true, the returned value also contain the CA (issuer) +// certificate chain. Otherwise, only a leaf certificate is returned. 
+// The returned URL can be used to re-fetch the certificate using FetchCert. +// +// This method is only supported by CAs implementing RFC 8555. See CreateCert for pre-RFC CAs. +// +// CreateOrderCert returns an error if the CA's response is unreasonably large. +// Callers are encouraged to parse the returned value to ensure the certificate is valid and has the expected features. +func (c *Client) CreateOrderCert(ctx context.Context, url string, csr []byte, bundle bool) (der [][]byte, certURL string, err error) { + if _, err := c.Discover(ctx); err != nil { // required by c.accountKID + return nil, "", err + } + + // RFC describes this as "finalize order" request. + req := struct { + CSR string `json:"csr"` + }{ + CSR: base64.RawURLEncoding.EncodeToString(csr), + } + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return nil, "", err + } + defer res.Body.Close() + o, err := responseOrder(res) + if err != nil { + return nil, "", err + } + + // Wait for CA to issue the cert if they haven't. + if o.Status != StatusValid { + o, err = c.WaitOrder(ctx, o.URI) + } + if err != nil { + return nil, "", err + } + // The only acceptable status post finalize and WaitOrder is "valid". + if o.Status != StatusValid { + return nil, "", &OrderError{OrderURL: o.URI, Status: o.Status} + } + crt, err := c.fetchCertRFC(ctx, o.CertURL, bundle) + return crt, o.CertURL, err +} + +// fetchCertRFC downloads issued certificate from the given URL. +// It expects the CA to respond with PEM-encoded certificate chain. +// +// The URL argument is the CertURL field of Order. +func (c *Client) fetchCertRFC(ctx context.Context, url string, bundle bool) ([][]byte, error) { + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Get all the bytes up to a sane maximum. + // Account very roughly for base64 overhead. 
+ const max = maxCertChainSize + maxCertChainSize/33 + b, err := io.ReadAll(io.LimitReader(res.Body, max+1)) + if err != nil { + return nil, fmt.Errorf("acme: fetch cert response stream: %v", err) + } + if len(b) > max { + return nil, errors.New("acme: certificate chain is too big") + } + + // Decode PEM chain. + var chain [][]byte + for { + var p *pem.Block + p, b = pem.Decode(b) + if p == nil { + break + } + if p.Type != "CERTIFICATE" { + return nil, fmt.Errorf("acme: invalid PEM cert type %q", p.Type) + } + + chain = append(chain, p.Bytes) + if !bundle { + return chain, nil + } + if len(chain) > maxChainLen { + return nil, errors.New("acme: certificate chain is too long") + } + } + if len(chain) == 0 { + return nil, errors.New("acme: certificate chain is empty") + } + return chain, nil +} + +// sends a cert revocation request in either JWK form when key is non-nil or KID form otherwise. +func (c *Client) revokeCertRFC(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error { + req := &struct { + Cert string `json:"certificate"` + Reason int `json:"reason"` + }{ + Cert: base64.RawURLEncoding.EncodeToString(cert), + Reason: int(reason), + } + res, err := c.post(ctx, key, c.dir.RevokeURL, req, wantStatus(http.StatusOK)) + if err != nil { + if isAlreadyRevoked(err) { + // Assume it is not an error to revoke an already revoked cert. + return nil + } + return err + } + defer res.Body.Close() + return nil +} + +func isAlreadyRevoked(err error) bool { + e, ok := err.(*Error) + return ok && e.ProblemType == "urn:ietf:params:acme:error:alreadyRevoked" +} + +// ListCertAlternates retrieves any alternate certificate chain URLs for the +// given certificate chain URL. These alternate URLs can be passed to FetchCert +// in order to retrieve the alternate certificate chains. +// +// If there are no alternate issuer certificate chains, a nil slice will be +// returned. 
+func (c *Client) ListCertAlternates(ctx context.Context, url string) ([]string, error) { + if _, err := c.Discover(ctx); err != nil { // required by c.accountKID + return nil, err + } + + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // We don't need the body but we need to discard it so we don't end up + // preventing keep-alive + if _, err := io.Copy(io.Discard, res.Body); err != nil { + return nil, fmt.Errorf("acme: cert alternates response stream: %v", err) + } + alts := linkHeader(res.Header, "alternate") + return alts, nil +} diff --git a/vendor/github.com/tailscale/golang-x-crypto/acme/types.go b/vendor/github.com/tailscale/golang-x-crypto/acme/types.go new file mode 100644 index 0000000000..9fad800b4a --- /dev/null +++ b/vendor/github.com/tailscale/golang-x-crypto/acme/types.go @@ -0,0 +1,632 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "crypto" + "crypto/x509" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +// ACME status values of Account, Order, Authorization and Challenge objects. +// See https://tools.ietf.org/html/rfc8555#section-7.1.6 for details. +const ( + StatusDeactivated = "deactivated" + StatusExpired = "expired" + StatusInvalid = "invalid" + StatusPending = "pending" + StatusProcessing = "processing" + StatusReady = "ready" + StatusRevoked = "revoked" + StatusUnknown = "unknown" + StatusValid = "valid" +) + +// CRLReasonCode identifies the reason for a certificate revocation. +type CRLReasonCode int + +// CRL reason codes as defined in RFC 5280. 
+const ( + CRLReasonUnspecified CRLReasonCode = 0 + CRLReasonKeyCompromise CRLReasonCode = 1 + CRLReasonCACompromise CRLReasonCode = 2 + CRLReasonAffiliationChanged CRLReasonCode = 3 + CRLReasonSuperseded CRLReasonCode = 4 + CRLReasonCessationOfOperation CRLReasonCode = 5 + CRLReasonCertificateHold CRLReasonCode = 6 + CRLReasonRemoveFromCRL CRLReasonCode = 8 + CRLReasonPrivilegeWithdrawn CRLReasonCode = 9 + CRLReasonAACompromise CRLReasonCode = 10 +) + +var ( + // ErrUnsupportedKey is returned when an unsupported key type is encountered. + ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported") + + // ErrAccountAlreadyExists indicates that the Client's key has already been registered + // with the CA. It is returned by Register method. + ErrAccountAlreadyExists = errors.New("acme: account already exists") + + // ErrNoAccount indicates that the Client's key has not been registered with the CA. + ErrNoAccount = errors.New("acme: account does not exist") +) + +// A Subproblem describes an ACME subproblem as reported in an Error. +type Subproblem struct { + // Type is a URI reference that identifies the problem type, + // typically in a "urn:acme:error:xxx" form. + Type string + // Detail is a human-readable explanation specific to this occurrence of the problem. + Detail string + // Instance indicates a URL that the client should direct a human user to visit + // in order for instructions on how to agree to the updated Terms of Service. + // In such an event CA sets StatusCode to 403, Type to + // "urn:ietf:params:acme:error:userActionRequired", and adds a Link header with relation + // "terms-of-service" containing the latest TOS URL. + Instance string + // Identifier may contain the ACME identifier that the error is for. 
+ Identifier *AuthzID +} + +func (sp Subproblem) String() string { + str := fmt.Sprintf("%s: ", sp.Type) + if sp.Identifier != nil { + str += fmt.Sprintf("[%s: %s] ", sp.Identifier.Type, sp.Identifier.Value) + } + str += sp.Detail + return str +} + +// Error is an ACME error, defined in Problem Details for HTTP APIs doc +// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem. +type Error struct { + // StatusCode is The HTTP status code generated by the origin server. + StatusCode int + // ProblemType is a URI reference that identifies the problem type, + // typically in a "urn:acme:error:xxx" form. + ProblemType string + // Detail is a human-readable explanation specific to this occurrence of the problem. + Detail string + // Instance indicates a URL that the client should direct a human user to visit + // in order for instructions on how to agree to the updated Terms of Service. + // In such an event CA sets StatusCode to 403, ProblemType to + // "urn:ietf:params:acme:error:userActionRequired" and a Link header with relation + // "terms-of-service" containing the latest TOS URL. + Instance string + // Header is the original server error response headers. + // It may be nil. + Header http.Header + // Subproblems may contain more detailed information about the individual problems + // that caused the error. This field is only sent by RFC 8555 compatible ACME + // servers. Defined in RFC 8555 Section 6.7.1. + Subproblems []Subproblem +} + +func (e *Error) Error() string { + str := fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail) + if len(e.Subproblems) > 0 { + str += fmt.Sprintf("; subproblems:") + for _, sp := range e.Subproblems { + str += fmt.Sprintf("\n\t%s", sp) + } + } + return str +} + +// AuthorizationError indicates that an authorization for an identifier +// did not succeed. +// It contains all errors from Challenge items of the failed Authorization. 
+type AuthorizationError struct { + // URI uniquely identifies the failed Authorization. + URI string + + // Identifier is an AuthzID.Value of the failed Authorization. + Identifier string + + // Errors is a collection of non-nil error values of Challenge items + // of the failed Authorization. + Errors []error +} + +func (a *AuthorizationError) Error() string { + e := make([]string, len(a.Errors)) + for i, err := range a.Errors { + e[i] = err.Error() + } + + if a.Identifier != "" { + return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; ")) + } + + return fmt.Sprintf("acme: authorization error: %s", strings.Join(e, "; ")) +} + +// OrderError is returned from Client's order related methods. +// It indicates the order is unusable and the clients should start over with +// AuthorizeOrder. +// +// The clients can still fetch the order object from CA using GetOrder +// to inspect its state. +type OrderError struct { + OrderURL string + Status string +} + +func (oe *OrderError) Error() string { + return fmt.Sprintf("acme: order %s status: %s", oe.OrderURL, oe.Status) +} + +// RateLimit reports whether err represents a rate limit error and +// any Retry-After duration returned by the server. +// +// See the following for more details on rate limiting: +// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6 +func RateLimit(err error) (time.Duration, bool) { + e, ok := err.(*Error) + if !ok { + return 0, false + } + // Some CA implementations may return incorrect values. + // Use case-insensitive comparison. + if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") { + return 0, false + } + if e.Header == nil { + return 0, true + } + return retryAfter(e.Header.Get("Retry-After")), true +} + +// Account is a user account. It is associated with a private key. +// Non-RFC 8555 fields are empty when interfacing with a compliant CA. 
+type Account struct { + // URI is the account unique ID, which is also a URL used to retrieve + // account data from the CA. + // When interfacing with RFC 8555-compliant CAs, URI is the "kid" field + // value in JWS signed requests. + URI string + + // Contact is a slice of contact info used during registration. + // See https://tools.ietf.org/html/rfc8555#section-7.3 for supported + // formats. + Contact []string + + // Status indicates current account status as returned by the CA. + // Possible values are StatusValid, StatusDeactivated, and StatusRevoked. + Status string + + // OrdersURL is a URL from which a list of orders submitted by this account + // can be fetched. + OrdersURL string + + // The terms user has agreed to. + // A value not matching CurrentTerms indicates that the user hasn't agreed + // to the actual Terms of Service of the CA. + // + // It is non-RFC 8555 compliant. Package users can store the ToS they agree to + // during Client's Register call in the prompt callback function. + AgreedTerms string + + // Actual terms of a CA. + // + // It is non-RFC 8555 compliant. Use Directory's Terms field. + // When a CA updates their terms and requires an account agreement, + // a URL at which instructions to do so is available in Error's Instance field. + CurrentTerms string + + // Authz is the authorization URL used to initiate a new authz flow. + // + // It is non-RFC 8555 compliant. Use Directory's AuthzURL or OrderURL. + Authz string + + // Authorizations is a URI from which a list of authorizations + // granted to this account can be fetched via a GET request. + // + // It is non-RFC 8555 compliant and is obsoleted by OrdersURL. + Authorizations string + + // Certificates is a URI from which a list of certificates + // issued for this account can be fetched via a GET request. + // + // It is non-RFC 8555 compliant and is obsoleted by OrdersURL. 
+ Certificates string + + // ExternalAccountBinding represents an arbitrary binding to an account of + // the CA which the ACME server is tied to. + // See https://tools.ietf.org/html/rfc8555#section-7.3.4 for more details. + ExternalAccountBinding *ExternalAccountBinding +} + +// ExternalAccountBinding contains the data needed to form a request with +// an external account binding. +// See https://tools.ietf.org/html/rfc8555#section-7.3.4 for more details. +type ExternalAccountBinding struct { + // KID is the Key ID of the symmetric MAC key that the CA provides to + // identify an external account from ACME. + KID string + + // Key is the bytes of the symmetric key that the CA provides to identify + // the account. Key must correspond to the KID. + Key []byte +} + +func (e *ExternalAccountBinding) String() string { + return fmt.Sprintf("&{KID: %q, Key: redacted}", e.KID) +} + +// Directory is ACME server discovery data. +// See https://tools.ietf.org/html/rfc8555#section-7.1.1 for more details. +type Directory struct { + // NonceURL indicates an endpoint where to fetch fresh nonce values from. + NonceURL string + + // RegURL is an account endpoint URL, allowing for creating new accounts. + // Pre-RFC 8555 CAs also allow modifying existing accounts at this URL. + RegURL string + + // OrderURL is used to initiate the certificate issuance flow + // as described in RFC 8555. + OrderURL string + + // AuthzURL is used to initiate identifier pre-authorization flow. + // Empty string indicates the flow is unsupported by the CA. + AuthzURL string + + // CertURL is a new certificate issuance endpoint URL. + // It is non-RFC 8555 compliant and is obsoleted by OrderURL. + CertURL string + + // RevokeURL is used to initiate a certificate revocation flow. + RevokeURL string + + // KeyChangeURL allows to perform account key rollover flow. + KeyChangeURL string + + // RenewalInfoURL allows to perform certificate renewal using the ACME + // Renewal Information (ARI) Extension. 
+ RenewalInfoURL string + + // Term is a URI identifying the current terms of service. + Terms string + + // Website is an HTTP or HTTPS URL locating a website + // providing more information about the ACME server. + Website string + + // CAA consists of lowercase hostname elements, which the ACME server + // recognises as referring to itself for the purposes of CAA record validation + // as defined in RFC 6844. + CAA []string + + // ExternalAccountRequired indicates that the CA requires for all account-related + // requests to include external account binding information. + ExternalAccountRequired bool +} + +// Order represents a client's request for a certificate. +// It tracks the request flow progress through to issuance. +type Order struct { + // URI uniquely identifies an order. + URI string + + // Status represents the current status of the order. + // It indicates which action the client should take. + // + // Possible values are StatusPending, StatusReady, StatusProcessing, StatusValid and StatusInvalid. + // Pending means the CA does not believe that the client has fulfilled the requirements. + // Ready indicates that the client has fulfilled all the requirements and can submit a CSR + // to obtain a certificate. This is done with Client's CreateOrderCert. + // Processing means the certificate is being issued. + // Valid indicates the CA has issued the certificate. It can be downloaded + // from the Order's CertURL. This is done with Client's FetchCert. + // Invalid means the certificate will not be issued. Users should consider this order + // abandoned. + Status string + + // Expires is the timestamp after which CA considers this order invalid. + Expires time.Time + + // Identifiers contains all identifier objects which the order pertains to. + Identifiers []AuthzID + + // NotBefore is the requested value of the notBefore field in the certificate. + NotBefore time.Time + + // NotAfter is the requested value of the notAfter field in the certificate. 
+ NotAfter time.Time + + // AuthzURLs represents authorizations to complete before a certificate + // for identifiers specified in the order can be issued. + // It also contains unexpired authorizations that the client has completed + // in the past. + // + // Authorization objects can be fetched using Client's GetAuthorization method. + // + // The required authorizations are dictated by CA policies. + // There may not be a 1:1 relationship between the identifiers and required authorizations. + // Required authorizations can be identified by their StatusPending status. + // + // For orders in the StatusValid or StatusInvalid state these are the authorizations + // which were completed. + AuthzURLs []string + + // FinalizeURL is the endpoint at which a CSR is submitted to obtain a certificate + // once all the authorizations are satisfied. + FinalizeURL string + + // CertURL points to the certificate that has been issued in response to this order. + CertURL string + + // The error that occurred while processing the order as received from a CA, if any. + Error *Error +} + +// OrderOption allows customizing Client.AuthorizeOrder call. +type OrderOption interface { + privateOrderOpt() +} + +// WithOrderNotBefore sets order's NotBefore field. +func WithOrderNotBefore(t time.Time) OrderOption { + return orderNotBeforeOpt(t) +} + +// WithOrderNotAfter sets order's NotAfter field. +func WithOrderNotAfter(t time.Time) OrderOption { + return orderNotAfterOpt(t) +} + +type orderNotBeforeOpt time.Time + +func (orderNotBeforeOpt) privateOrderOpt() {} + +type orderNotAfterOpt time.Time + +func (orderNotAfterOpt) privateOrderOpt() {} + +// Authorization encodes an authorization response. +type Authorization struct { + // URI uniquely identifies a authorization. + URI string + + // Status is the current status of an authorization. + // Possible values are StatusPending, StatusValid, StatusInvalid, StatusDeactivated, + // StatusExpired and StatusRevoked. 
+ Status string + + // Identifier is what the account is authorized to represent. + Identifier AuthzID + + // The timestamp after which the CA considers the authorization invalid. + Expires time.Time + + // Wildcard is true for authorizations of a wildcard domain name. + Wildcard bool + + // Challenges that the client needs to fulfill in order to prove possession + // of the identifier (for pending authorizations). + // For valid authorizations, the challenge that was validated. + // For invalid authorizations, the challenge that was attempted and failed. + // + // RFC 8555 compatible CAs require users to fuflfill only one of the challenges. + Challenges []*Challenge + + // A collection of sets of challenges, each of which would be sufficient + // to prove possession of the identifier. + // Clients must complete a set of challenges that covers at least one set. + // Challenges are identified by their indices in the challenges array. + // If this field is empty, the client needs to complete all challenges. + // + // This field is unused in RFC 8555. + Combinations [][]int +} + +// AuthzID is an identifier that an account is authorized to represent. +type AuthzID struct { + Type string // The type of identifier, "dns" or "ip". + Value string // The identifier itself, e.g. "example.org". +} + +// DomainIDs creates a slice of AuthzID with "dns" identifier type. +func DomainIDs(names ...string) []AuthzID { + a := make([]AuthzID, len(names)) + for i, v := range names { + a[i] = AuthzID{Type: "dns", Value: v} + } + return a +} + +// IPIDs creates a slice of AuthzID with "ip" identifier type. +// Each element of addr is textual form of an address as defined +// in RFC 1123 Section 2.1 for IPv4 and in RFC 5952 Section 4 for IPv6. +func IPIDs(addr ...string) []AuthzID { + a := make([]AuthzID, len(addr)) + for i, v := range addr { + a[i] = AuthzID{Type: "ip", Value: v} + } + return a +} + +// wireAuthzID is ACME JSON representation of authorization identifier objects. 
+type wireAuthzID struct { + Type string `json:"type"` + Value string `json:"value"` +} + +// wireAuthz is ACME JSON representation of Authorization objects. +type wireAuthz struct { + Identifier wireAuthzID + Status string + Expires time.Time + Wildcard bool + Challenges []wireChallenge + Combinations [][]int + Error *wireError +} + +func (z *wireAuthz) authorization(uri string) *Authorization { + a := &Authorization{ + URI: uri, + Status: z.Status, + Identifier: AuthzID{Type: z.Identifier.Type, Value: z.Identifier.Value}, + Expires: z.Expires, + Wildcard: z.Wildcard, + Challenges: make([]*Challenge, len(z.Challenges)), + Combinations: z.Combinations, // shallow copy + } + for i, v := range z.Challenges { + a.Challenges[i] = v.challenge() + } + return a +} + +func (z *wireAuthz) error(uri string) *AuthorizationError { + err := &AuthorizationError{ + URI: uri, + Identifier: z.Identifier.Value, + } + + if z.Error != nil { + err.Errors = append(err.Errors, z.Error.error(nil)) + } + + for _, raw := range z.Challenges { + if raw.Error != nil { + err.Errors = append(err.Errors, raw.Error.error(nil)) + } + } + + return err +} + +// Challenge encodes a returned CA challenge. +// Its Error field may be non-nil if the challenge is part of an Authorization +// with StatusInvalid. +type Challenge struct { + // Type is the challenge type, e.g. "http-01", "tls-alpn-01", "dns-01". + Type string + + // URI is where a challenge response can be posted to. + URI string + + // Token is a random value that uniquely identifies the challenge. + Token string + + // Status identifies the status of this challenge. + // In RFC 8555, possible values are StatusPending, StatusProcessing, StatusValid, + // and StatusInvalid. + Status string + + // Validated is the time at which the CA validated this challenge. + // Always zero value in pre-RFC 8555. + Validated time.Time + + // Error indicates the reason for an authorization failure + // when this challenge was used. 
+ // The type of a non-nil value is *Error. + Error error +} + +// wireChallenge is ACME JSON challenge representation. +type wireChallenge struct { + URL string `json:"url"` // RFC + URI string `json:"uri"` // pre-RFC + Type string + Token string + Status string + Validated time.Time + Error *wireError +} + +func (c *wireChallenge) challenge() *Challenge { + v := &Challenge{ + URI: c.URL, + Type: c.Type, + Token: c.Token, + Status: c.Status, + } + if v.URI == "" { + v.URI = c.URI // c.URL was empty; use legacy + } + if v.Status == "" { + v.Status = StatusPending + } + if c.Error != nil { + v.Error = c.Error.error(nil) + } + return v +} + +// wireError is a subset of fields of the Problem Details object +// as described in https://tools.ietf.org/html/rfc7807#section-3.1. +type wireError struct { + Status int + Type string + Detail string + Instance string + Subproblems []Subproblem +} + +func (e *wireError) error(h http.Header) *Error { + err := &Error{ + StatusCode: e.Status, + ProblemType: e.Type, + Detail: e.Detail, + Instance: e.Instance, + Header: h, + Subproblems: e.Subproblems, + } + return err +} + +// CertOption is an optional argument type for the TLS ChallengeCert methods for +// customizing a temporary certificate for TLS-based challenges. +type CertOption interface { + privateCertOpt() +} + +// WithKey creates an option holding a private/public key pair. +// The private part signs a certificate, and the public part represents the signee. +func WithKey(key crypto.Signer) CertOption { + return &certOptKey{key} +} + +type certOptKey struct { + key crypto.Signer +} + +func (*certOptKey) privateCertOpt() {} + +// WithTemplate creates an option for specifying a certificate template. +// See x509.CreateCertificate for template usage details. +// +// In TLS ChallengeCert methods, the template is also used as parent, +// resulting in a self-signed certificate. +// The DNSNames field of t is always overwritten for tls-sni challenge certs. 
+func WithTemplate(t *x509.Certificate) CertOption { + return (*certOptTemplate)(t) +} + +type certOptTemplate x509.Certificate + +func (*certOptTemplate) privateCertOpt() {} + +// RenewalInfoWindow describes the time frame during which the ACME client +// should attempt to renew, using the ACME Renewal Info Extension. +type RenewalInfoWindow struct { + Start time.Time `json:"start"` + End time.Time `json:"end"` +} + +// RenewalInfo describes the suggested renewal window for a given certificate, +// returned from an ACME server, using the ACME Renewal Info Extension. +type RenewalInfo struct { + SuggestedWindow RenewalInfoWindow `json:"suggestedWindow"` + ExplanationURL string `json:"explanationURL"` +} diff --git a/vendor/github.com/tailscale/golang-x-crypto/acme/version_go112.go b/vendor/github.com/tailscale/golang-x-crypto/acme/version_go112.go new file mode 100644 index 0000000000..b9efdb59e5 --- /dev/null +++ b/vendor/github.com/tailscale/golang-x-crypto/acme/version_go112.go @@ -0,0 +1,28 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.12 +// +build go1.12 + +package acme + +import "runtime/debug" + +func init() { + // Set packageVersion if the binary was built in modules mode and x/crypto + // was not replaced with a different module. 
+ info, ok := debug.ReadBuildInfo() + if !ok { + return + } + for _, m := range info.Deps { + if m.Path != "golang.org/x/crypto" { + continue + } + if m.Replace == nil { + packageVersion = m.Version + } + break + } +} diff --git a/vendor/github.com/tailscale/golang-x-crypto/ssh/cipher.go b/vendor/github.com/tailscale/golang-x-crypto/ssh/cipher.go index 1c380f4607..951812ff9b 100644 --- a/vendor/github.com/tailscale/golang-x-crypto/ssh/cipher.go +++ b/vendor/github.com/tailscale/golang-x-crypto/ssh/cipher.go @@ -114,7 +114,8 @@ var cipherModes = map[string]*cipherMode{ "arcfour": {16, 0, streamCipherMode(0, newRC4)}, // AEAD ciphers - gcmCipherID: {16, 12, newGCMCipher}, + gcm128CipherID: {16, 12, newGCMCipher}, + gcm256CipherID: {32, 12, newGCMCipher}, chacha20Poly1305ID: {64, 0, newChaCha20Cipher}, // CBC mode is insecure and so is not included in the default config. diff --git a/vendor/github.com/tailscale/golang-x-crypto/ssh/common.go b/vendor/github.com/tailscale/golang-x-crypto/ssh/common.go index c7964275de..5ce452bae8 100644 --- a/vendor/github.com/tailscale/golang-x-crypto/ssh/common.go +++ b/vendor/github.com/tailscale/golang-x-crypto/ssh/common.go @@ -28,7 +28,7 @@ const ( // supportedCiphers lists ciphers we support but might not recommend. var supportedCiphers = []string{ "aes128-ctr", "aes192-ctr", "aes256-ctr", - "aes128-gcm@openssh.com", + "aes128-gcm@openssh.com", gcm256CipherID, chacha20Poly1305ID, "arcfour256", "arcfour128", "arcfour", aes128cbcID, @@ -37,7 +37,7 @@ var supportedCiphers = []string{ // preferredCiphers specifies the default preference for ciphers. var preferredCiphers = []string{ - "aes128-gcm@openssh.com", + "aes128-gcm@openssh.com", gcm256CipherID, chacha20Poly1305ID, "aes128-ctr", "aes192-ctr", "aes256-ctr", } @@ -85,7 +85,7 @@ var supportedHostKeyAlgos = []string{ // This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed // because they have reached the end of their useful life. 
var supportedMACs = []string{ - "hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", + "hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96", } var supportedCompressions = []string{compressionNone} @@ -119,6 +119,13 @@ func algorithmsForKeyFormat(keyFormat string) []string { } } +// isRSA returns whether algo is a supported RSA algorithm, including certificate +// algorithms. +func isRSA(algo string) bool { + algos := algorithmsForKeyFormat(KeyAlgoRSA) + return contains(algos, underlyingAlgo(algo)) +} + // supportedPubKeyAuthAlgos specifies the supported client public key // authentication algorithms. Note that this doesn't include certificate types // since those use the underlying algorithm. This list is sent to the client if @@ -168,7 +175,7 @@ func (a *directionAlgorithms) rekeyBytes() int64 { // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is // 128. switch a.Cipher { - case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID: + case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID: return 16 * (1 << 32) } @@ -178,7 +185,8 @@ func (a *directionAlgorithms) rekeyBytes() int64 { } var aeadCiphers = map[string]bool{ - gcmCipherID: true, + gcm128CipherID: true, + gcm256CipherID: true, chacha20Poly1305ID: true, } diff --git a/vendor/github.com/tailscale/golang-x-crypto/ssh/connection.go b/vendor/github.com/tailscale/golang-x-crypto/ssh/connection.go index 6f89ab6f34..4116fb2988 100644 --- a/vendor/github.com/tailscale/golang-x-crypto/ssh/connection.go +++ b/vendor/github.com/tailscale/golang-x-crypto/ssh/connection.go @@ -102,7 +102,7 @@ func (c *connection) Close() error { return c.sshConn.conn.Close() } -// sshconn provides net.Conn metadata, but disallows direct reads and +// sshConn provides net.Conn metadata, but disallows direct reads and // writes. 
type sshConn struct { conn net.Conn diff --git a/vendor/github.com/tailscale/golang-x-crypto/ssh/handshake.go b/vendor/github.com/tailscale/golang-x-crypto/ssh/handshake.go index 2b84c35716..07a1843e0a 100644 --- a/vendor/github.com/tailscale/golang-x-crypto/ssh/handshake.go +++ b/vendor/github.com/tailscale/golang-x-crypto/ssh/handshake.go @@ -58,11 +58,13 @@ type handshakeTransport struct { incoming chan []byte readError error - mu sync.Mutex - writeError error - sentInitPacket []byte - sentInitMsg *kexInitMsg - pendingPackets [][]byte // Used when a key exchange is in progress. + mu sync.Mutex + writeError error + sentInitPacket []byte + sentInitMsg *kexInitMsg + pendingPackets [][]byte // Used when a key exchange is in progress. + writePacketsLeft uint32 + writeBytesLeft int64 // If the read loop wants to schedule a kex, it pings this // channel, and the write loop will send out a kex @@ -71,7 +73,8 @@ type handshakeTransport struct { // If the other side requests or confirms a kex, its kexInit // packet is sent here for the write loop to find it. - startKex chan *pendingKex + startKex chan *pendingKex + kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits // data for host key checking hostKeyCallback HostKeyCallback @@ -86,12 +89,10 @@ type handshakeTransport struct { // Algorithms agreed in the last key exchange. algorithms *algorithms + // Counters exclusively owned by readLoop. readPacketsLeft uint32 readBytesLeft int64 - writePacketsLeft uint32 - writeBytesLeft int64 - // The session ID or nil if first kex did not complete yet. 
sessionID []byte } @@ -108,7 +109,8 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, clientVersion: clientVersion, incoming: make(chan []byte, chanSize), requestKex: make(chan struct{}, 1), - startKex: make(chan *pendingKex, 1), + startKex: make(chan *pendingKex), + kexLoopDone: make(chan struct{}), config: config, } @@ -340,16 +342,17 @@ write: t.mu.Unlock() } + // Unblock reader. + t.conn.Close() + // drain startKex channel. We don't service t.requestKex // because nobody does blocking sends there. - go func() { - for init := range t.startKex { - init.done <- t.writeError - } - }() + for request := range t.startKex { + request.done <- t.getWriteError() + } - // Unblock reader. - t.conn.Close() + // Mark that the loop is done so that Close can return. + close(t.kexLoopDone) } // The protocol uses uint32 for packet counters, so we can't let them @@ -545,7 +548,16 @@ func (t *handshakeTransport) writePacket(p []byte) error { } func (t *handshakeTransport) Close() error { - return t.conn.Close() + // Close the connection. This should cause the readLoop goroutine to wake up + // and close t.startKex, which will shut down kexLoop if running. + err := t.conn.Close() + + // Wait for the kexLoop goroutine to complete. + // At that point we know that the readLoop goroutine is complete too, + // because kexLoop itself waits for readLoop to close the startKex channel. 
+ <-t.kexLoopDone + + return err } func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { diff --git a/vendor/github.com/tailscale/golang-x-crypto/ssh/keys.go b/vendor/github.com/tailscale/golang-x-crypto/ssh/keys.go index c1e7e8cbcd..28f3446e4c 100644 --- a/vendor/github.com/tailscale/golang-x-crypto/ssh/keys.go +++ b/vendor/github.com/tailscale/golang-x-crypto/ssh/keys.go @@ -1087,9 +1087,9 @@ func (*PassphraseMissingError) Error() string { return "ssh: this private key is passphrase protected" } -// ParseRawPrivateKey returns a private key from a PEM encoded private key. It -// supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys. If the -// private key is encrypted, it will return a PassphraseMissingError. +// ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports +// RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH +// formats. If the private key is encrypted, it will return a PassphraseMissingError. 
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { block, _ := pem.Decode(pemBytes) if block == nil { diff --git a/vendor/github.com/tailscale/golang-x-crypto/ssh/mac.go b/vendor/github.com/tailscale/golang-x-crypto/ssh/mac.go index c07a06285e..06a1b27507 100644 --- a/vendor/github.com/tailscale/golang-x-crypto/ssh/mac.go +++ b/vendor/github.com/tailscale/golang-x-crypto/ssh/mac.go @@ -10,6 +10,7 @@ import ( "crypto/hmac" "crypto/sha1" "crypto/sha256" + "crypto/sha512" "hash" ) @@ -46,9 +47,15 @@ func (t truncatingMAC) Size() int { func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } var macModes = map[string]*macMode{ + "hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash { + return hmac.New(sha512.New, key) + }}, "hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash { return hmac.New(sha256.New, key) }}, + "hmac-sha2-512": {64, false, func(key []byte) hash.Hash { + return hmac.New(sha512.New, key) + }}, "hmac-sha2-256": {32, false, func(key []byte) hash.Hash { return hmac.New(sha256.New, key) }}, diff --git a/vendor/github.com/tailscale/golang-x-crypto/ssh/server.go b/vendor/github.com/tailscale/golang-x-crypto/ssh/server.go index 24ba0a6546..d86617440e 100644 --- a/vendor/github.com/tailscale/golang-x-crypto/ssh/server.go +++ b/vendor/github.com/tailscale/golang-x-crypto/ssh/server.go @@ -382,6 +382,25 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c return authErr, perms, nil } +// isAlgoCompatible checks if the signature format is compatible with the +// selected algorithm taking into account edge cases that occur with old +// clients. +func isAlgoCompatible(algo, sigFormat string) bool { + // Compatibility for old clients. + // + // For certificate authentication with OpenSSH 7.2-7.7 signature format can + // be rsa-sha2-256 or rsa-sha2-512 for the algorithm + // ssh-rsa-cert-v01@openssh.com. 
+ // + // With gpg-agent < 2.2.6 the algorithm can be rsa-sha2-256 or rsa-sha2-512 + // for signature format ssh-rsa. + if isRSA(algo) && isRSA(sigFormat) { + return true + } + // Standard case: the underlying algorithm must match the signature format. + return underlyingAlgo(algo) == sigFormat +} + // ServerAuthError represents server authentication errors and is // sometimes returned by NewServerConn. It appends any authentication // errors that may occur, and is returned if all of the authentication @@ -587,7 +606,7 @@ userAuthLoop: authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format) break } - if underlyingAlgo(algo) != sig.Format { + if !isAlgoCompatible(algo, sig.Format) { authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo) break } diff --git a/vendor/github.com/tailscale/golang-x-crypto/ssh/transport.go b/vendor/github.com/tailscale/golang-x-crypto/ssh/transport.go index acf5a21bbb..da015801ea 100644 --- a/vendor/github.com/tailscale/golang-x-crypto/ssh/transport.go +++ b/vendor/github.com/tailscale/golang-x-crypto/ssh/transport.go @@ -17,7 +17,8 @@ import ( const debugTransport = false const ( - gcmCipherID = "aes128-gcm@openssh.com" + gcm128CipherID = "aes128-gcm@openssh.com" + gcm256CipherID = "aes256-gcm@openssh.com" aes128cbcID = "aes128-cbc" tripledescbcID = "3des-cbc" ) diff --git a/vendor/github.com/tailscale/wireguard-go/conn/bind_std.go b/vendor/github.com/tailscale/wireguard-go/conn/bind_std.go index c399e8b35d..0a9ff38f9c 100644 --- a/vendor/github.com/tailscale/wireguard-go/conn/bind_std.go +++ b/vendor/github.com/tailscale/wireguard-go/conn/bind_std.go @@ -144,6 +144,11 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn.(*net.UDPConn), uaddr.Port, nil } +// errEADDRINUSE is syscall.EADDRINUSE, boxed into an interface once +// in erraddrinuse.go on almost all platforms. For other platforms, +// it's at least non-nil. 
+var errEADDRINUSE error = errors.New("") + func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { s.mu.Lock() defer s.mu.Unlock() @@ -170,7 +175,7 @@ again: // Listen on the same port as we're using for ipv4. v6conn, port, err = listenNet("udp6", port) - if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + if uport == 0 && errors.Is(err, errEADDRINUSE) && tries < 100 { v4conn.Close() tries++ goto again diff --git a/vendor/github.com/tailscale/wireguard-go/conn/controlfns_unix.go b/vendor/github.com/tailscale/wireguard-go/conn/controlfns_unix.go index c4536d4bb4..5cc4d98f93 100644 --- a/vendor/github.com/tailscale/wireguard-go/conn/controlfns_unix.go +++ b/vendor/github.com/tailscale/wireguard-go/conn/controlfns_unix.go @@ -1,4 +1,4 @@ -//go:build !windows && !linux && !js +//go:build !windows && !linux && !wasm && !plan9 /* SPDX-License-Identifier: MIT * diff --git a/vendor/github.com/tailscale/wireguard-go/conn/erraddrinuse.go b/vendor/github.com/tailscale/wireguard-go/conn/erraddrinuse.go new file mode 100644 index 0000000000..c01a0a17f9 --- /dev/null +++ b/vendor/github.com/tailscale/wireguard-go/conn/erraddrinuse.go @@ -0,0 +1,14 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +//go:build !plan9 + +package conn + +import "syscall" + +func init() { + errEADDRINUSE = syscall.EADDRINUSE +} diff --git a/vendor/github.com/tailscale/wireguard-go/ipc/uapi_js.go b/vendor/github.com/tailscale/wireguard-go/ipc/uapi_fake.go similarity index 72% rename from vendor/github.com/tailscale/wireguard-go/ipc/uapi_js.go rename to vendor/github.com/tailscale/wireguard-go/ipc/uapi_fake.go index 2570515e20..346f49e335 100644 --- a/vendor/github.com/tailscale/wireguard-go/ipc/uapi_js.go +++ b/vendor/github.com/tailscale/wireguard-go/ipc/uapi_fake.go @@ -3,9 +3,11 @@ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/ +//go:build wasm || plan9 + package ipc -// Made up sentinel error codes for the js/wasm platform. +// Made up sentinel error codes for {js,wasip1}/wasm, and plan9. const ( IpcErrorIO = 1 IpcErrorInvalid = 2 diff --git a/vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel.go b/vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel.go index 63e1510b10..dd649d49ff 100644 --- a/vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel.go +++ b/vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel.go @@ -1,4 +1,4 @@ -//go:build !windows && !js +//go:build !windows && !wasm && !plan9 /* SPDX-License-Identifier: MIT * diff --git a/vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel_stub.go b/vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel_stub.go index 182940b32e..46238014c0 100644 --- a/vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel_stub.go +++ b/vendor/github.com/tailscale/wireguard-go/rwcancel/rwcancel_stub.go @@ -1,4 +1,4 @@ -//go:build windows || js +//go:build windows || wasm || plan9 // SPDX-License-Identifier: MIT diff --git a/vendor/github.com/tailscale/wireguard-go/tun/checksum.go b/vendor/github.com/tailscale/wireguard-go/tun/checksum.go index 29a8fc8fc0..ee3f359605 100644 --- a/vendor/github.com/tailscale/wireguard-go/tun/checksum.go +++ b/vendor/github.com/tailscale/wireguard-go/tun/checksum.go @@ -1,118 +1,710 @@ package tun -import "encoding/binary" +import ( + "encoding/binary" + "math/bits" + "strconv" -// TODO: Explore SIMD and/or other assembly optimizations. -// TODO: Test native endian loads. See RFC 1071 section 2 part B. -func checksumNoFold(b []byte, initial uint64) uint64 { - ac := initial + "golang.org/x/sys/cpu" +) + +// checksumGeneric64 is a reference implementation of checksum using 64 bit +// arithmetic for use in testing or when an architecture-specific implementation +// is not available. 
+func checksumGeneric64(b []byte, initial uint16) uint16 { + var ac uint64 + var carry uint64 + + if cpu.IsBigEndian { + ac = uint64(initial) + } else { + ac = uint64(bits.ReverseBytes16(initial)) + } for len(b) >= 128 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) - ac += uint64(binary.BigEndian.Uint32(b[32:36])) - ac += uint64(binary.BigEndian.Uint32(b[36:40])) - ac += uint64(binary.BigEndian.Uint32(b[40:44])) - ac += uint64(binary.BigEndian.Uint32(b[44:48])) - ac += uint64(binary.BigEndian.Uint32(b[48:52])) - ac += uint64(binary.BigEndian.Uint32(b[52:56])) - ac += uint64(binary.BigEndian.Uint32(b[56:60])) - ac += uint64(binary.BigEndian.Uint32(b[60:64])) - ac += uint64(binary.BigEndian.Uint32(b[64:68])) - ac += uint64(binary.BigEndian.Uint32(b[68:72])) - ac += uint64(binary.BigEndian.Uint32(b[72:76])) - ac += uint64(binary.BigEndian.Uint32(b[76:80])) - ac += uint64(binary.BigEndian.Uint32(b[80:84])) - ac += uint64(binary.BigEndian.Uint32(b[84:88])) - ac += uint64(binary.BigEndian.Uint32(b[88:92])) - ac += uint64(binary.BigEndian.Uint32(b[92:96])) - ac += uint64(binary.BigEndian.Uint32(b[96:100])) - ac += uint64(binary.BigEndian.Uint32(b[100:104])) - ac += uint64(binary.BigEndian.Uint32(b[104:108])) - ac += uint64(binary.BigEndian.Uint32(b[108:112])) - ac += uint64(binary.BigEndian.Uint32(b[112:116])) - ac += uint64(binary.BigEndian.Uint32(b[116:120])) - ac += uint64(binary.BigEndian.Uint32(b[120:124])) - ac += uint64(binary.BigEndian.Uint32(b[124:128])) + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, 
carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, 
binary.LittleEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry) + } b = b[128:] } if len(b) >= 64 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) - ac += uint64(binary.BigEndian.Uint32(b[32:36])) - ac += uint64(binary.BigEndian.Uint32(b[36:40])) - ac += uint64(binary.BigEndian.Uint32(b[40:44])) - ac += uint64(binary.BigEndian.Uint32(b[44:48])) - ac += uint64(binary.BigEndian.Uint32(b[48:52])) - ac += uint64(binary.BigEndian.Uint32(b[52:56])) - ac += uint64(binary.BigEndian.Uint32(b[56:60])) - ac += uint64(binary.BigEndian.Uint32(b[60:64])) + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry) + ac, 
carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry) + } else { + ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry) + } else { + ac, 
carry = bits.Add64(ac, uint64(b[0]), carry) + } + } + + folded := ipChecksumFold64(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32 is a reference implementation of checksum using 32 bit +// arithmetic for use in testing or when an architecture-specific implementation +// is not available. +func checksumGeneric32(b []byte, initial uint16) uint16 { + var ac uint32 + var carry uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, 
binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, 
binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry) + } else { + ac, carry = bits.Add32(ac, uint32(b[0]), carry) + } + } + + folded := ipChecksumFold32(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32Alternate is an alternate reference implementation 
of +// checksum using 32 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. +func checksumGeneric32Alternate(b []byte, initial uint16) uint16 { + var ac uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + ac += uint32(binary.BigEndian.Uint16(b[16:18])) + ac += uint32(binary.BigEndian.Uint16(b[18:20])) + ac += uint32(binary.BigEndian.Uint16(b[20:22])) + ac += uint32(binary.BigEndian.Uint16(b[22:24])) + ac += uint32(binary.BigEndian.Uint16(b[24:26])) + ac += uint32(binary.BigEndian.Uint16(b[26:28])) + ac += uint32(binary.BigEndian.Uint16(b[28:30])) + ac += uint32(binary.BigEndian.Uint16(b[30:32])) + ac += uint32(binary.BigEndian.Uint16(b[32:34])) + ac += uint32(binary.BigEndian.Uint16(b[34:36])) + ac += uint32(binary.BigEndian.Uint16(b[36:38])) + ac += uint32(binary.BigEndian.Uint16(b[38:40])) + ac += uint32(binary.BigEndian.Uint16(b[40:42])) + ac += uint32(binary.BigEndian.Uint16(b[42:44])) + ac += uint32(binary.BigEndian.Uint16(b[44:46])) + ac += uint32(binary.BigEndian.Uint16(b[46:48])) + ac += uint32(binary.BigEndian.Uint16(b[48:50])) + ac += uint32(binary.BigEndian.Uint16(b[50:52])) + ac += uint32(binary.BigEndian.Uint16(b[52:54])) + ac += uint32(binary.BigEndian.Uint16(b[54:56])) + ac += uint32(binary.BigEndian.Uint16(b[56:58])) + ac += uint32(binary.BigEndian.Uint16(b[58:60])) + ac += uint32(binary.BigEndian.Uint16(b[60:62])) + ac += uint32(binary.BigEndian.Uint16(b[62:64])) + } else { + ac += 
uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + ac += uint32(binary.LittleEndian.Uint16(b[16:18])) + ac += uint32(binary.LittleEndian.Uint16(b[18:20])) + ac += uint32(binary.LittleEndian.Uint16(b[20:22])) + ac += uint32(binary.LittleEndian.Uint16(b[22:24])) + ac += uint32(binary.LittleEndian.Uint16(b[24:26])) + ac += uint32(binary.LittleEndian.Uint16(b[26:28])) + ac += uint32(binary.LittleEndian.Uint16(b[28:30])) + ac += uint32(binary.LittleEndian.Uint16(b[30:32])) + ac += uint32(binary.LittleEndian.Uint16(b[32:34])) + ac += uint32(binary.LittleEndian.Uint16(b[34:36])) + ac += uint32(binary.LittleEndian.Uint16(b[36:38])) + ac += uint32(binary.LittleEndian.Uint16(b[38:40])) + ac += uint32(binary.LittleEndian.Uint16(b[40:42])) + ac += uint32(binary.LittleEndian.Uint16(b[42:44])) + ac += uint32(binary.LittleEndian.Uint16(b[44:46])) + ac += uint32(binary.LittleEndian.Uint16(b[46:48])) + ac += uint32(binary.LittleEndian.Uint16(b[48:50])) + ac += uint32(binary.LittleEndian.Uint16(b[50:52])) + ac += uint32(binary.LittleEndian.Uint16(b[52:54])) + ac += uint32(binary.LittleEndian.Uint16(b[54:56])) + ac += uint32(binary.LittleEndian.Uint16(b[56:58])) + ac += uint32(binary.LittleEndian.Uint16(b[58:60])) + ac += uint32(binary.LittleEndian.Uint16(b[60:62])) + ac += uint32(binary.LittleEndian.Uint16(b[62:64])) + } b = b[64:] } if len(b) >= 32 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += 
uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + ac += uint32(binary.BigEndian.Uint16(b[16:18])) + ac += uint32(binary.BigEndian.Uint16(b[18:20])) + ac += uint32(binary.BigEndian.Uint16(b[20:22])) + ac += uint32(binary.BigEndian.Uint16(b[22:24])) + ac += uint32(binary.BigEndian.Uint16(b[24:26])) + ac += uint32(binary.BigEndian.Uint16(b[26:28])) + ac += uint32(binary.BigEndian.Uint16(b[28:30])) + ac += uint32(binary.BigEndian.Uint16(b[30:32])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + ac += uint32(binary.LittleEndian.Uint16(b[16:18])) + ac += uint32(binary.LittleEndian.Uint16(b[18:20])) + ac += uint32(binary.LittleEndian.Uint16(b[20:22])) + ac += uint32(binary.LittleEndian.Uint16(b[22:24])) + ac += uint32(binary.LittleEndian.Uint16(b[24:26])) + ac += uint32(binary.LittleEndian.Uint16(b[26:28])) + ac += uint32(binary.LittleEndian.Uint16(b[28:30])) + ac += uint32(binary.LittleEndian.Uint16(b[30:32])) + } b = b[32:] } if len(b) >= 16 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += 
uint64(binary.BigEndian.Uint32(b[12:16])) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + } b = b[16:] } if len(b) >= 8 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + } b = b[8:] } if len(b) >= 4 { - ac += uint64(binary.BigEndian.Uint32(b)) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + } b = b[4:] } if len(b) >= 2 { - ac += uint64(binary.BigEndian.Uint16(b)) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b)) + } else { + ac += uint32(binary.LittleEndian.Uint16(b)) + } b = b[2:] } - if len(b) == 1 { - ac += uint64(b[0]) 
<< 8 + if len(b) >= 1 { + if cpu.IsBigEndian { + ac += uint32(b[0]) << 8 + } else { + ac += uint32(b[0]) + } } - return ac + folded := ipChecksumFold32(ac, 0) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded } -func checksum(b []byte, initial uint64) uint16 { - ac := checksumNoFold(b, initial) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - return uint16(ac) +// checksumGeneric64Alternate is an alternate reference implementation of +// checksum using 64 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. +func checksumGeneric64Alternate(b []byte, initial uint16) uint16 { + var ac uint64 + + if cpu.IsBigEndian { + ac = uint64(initial) + } else { + ac = uint64(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + ac += uint64(binary.LittleEndian.Uint32(b[16:20])) + ac += 
uint64(binary.LittleEndian.Uint32(b[20:24])) + ac += uint64(binary.LittleEndian.Uint32(b[24:28])) + ac += uint64(binary.LittleEndian.Uint32(b[28:32])) + ac += uint64(binary.LittleEndian.Uint32(b[32:36])) + ac += uint64(binary.LittleEndian.Uint32(b[36:40])) + ac += uint64(binary.LittleEndian.Uint32(b[40:44])) + ac += uint64(binary.LittleEndian.Uint32(b[44:48])) + ac += uint64(binary.LittleEndian.Uint32(b[48:52])) + ac += uint64(binary.LittleEndian.Uint32(b[52:56])) + ac += uint64(binary.LittleEndian.Uint32(b[56:60])) + ac += uint64(binary.LittleEndian.Uint32(b[60:64])) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + ac += uint64(binary.LittleEndian.Uint32(b[16:20])) + ac += uint64(binary.LittleEndian.Uint32(b[20:24])) + ac += uint64(binary.LittleEndian.Uint32(b[24:28])) + ac += uint64(binary.LittleEndian.Uint32(b[28:32])) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { 
+ ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b)) + } else { + ac += uint64(binary.LittleEndian.Uint32(b)) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint16(b)) + } else { + ac += uint64(binary.LittleEndian.Uint16(b)) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac += uint64(b[0]) << 8 + } else { + ac += uint64(b[0]) + } + } + + folded := ipChecksumFold64(ac, 0) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 { + sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry)) + // if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff + // therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is + // no need to save the carry flag + sum = (sum >> 16) + (sum & 0xffff) + carry + // sum <= 0x1_fffe therefore this is the last fold needed: + // if (sum >> 16) > 0 then + // (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore + // the addition will not overflow + // otherwise (sum >> 16) == 0 and sum will be unchanged + sum = (sum >> 16) + (sum & 0xffff) + return uint16(sum) +} + +func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 { + sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry + // sum <= 0x1_ffff: + // 0xffff + 0xffff = 0x1_fffe + // initialCarry is 0 or 1, for a combined maximum of 0x1_ffff + sum = (sum >> 16) + (sum & 0xffff) + // sum <= 0x1_0000 therefore this is the last fold needed: + // if (sum >> 16) > 0 then + // (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore + // the addition will not overflow + // otherwise (sum >> 16) == 0 
and sum will be unchanged + sum = (sum >> 16) + (sum & 0xffff) + return uint16(sum) +} + +func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) { + sum, carry = initial, carryIn + switch len(addr) { + case 4: // IPv4 + if cpu.IsBigEndian { + sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry) + } else { + sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry) + } + case 16: // IPv6 + if cpu.IsBigEndian { + sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry) + sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry) + } else { + sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry) + sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry) + } + default: + panic("bad addr length") + } + return sum, carry +} + +func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) { + sum, carry = initial, carryIn + switch len(addr) { + case 4: // IPv4 + if cpu.IsBigEndian { + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry) + } else { + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry) + } + case 16: // IPv6 + if cpu.IsBigEndian { + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry) + } else { + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry) + } + default: + panic("bad addr length") + } + return sum, carry +} + +func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen 
uint16) uint16 { + var sum uint64 + if cpu.IsBigEndian { + sum = uint64(totalLen) + uint64(protocol) + } else { + sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8 + } + sum, carry := addrPartialChecksum64(srcAddr, sum, 0) + sum, carry = addrPartialChecksum64(dstAddr, sum, carry) + + foldedSum := ipChecksumFold64(sum, carry) + if !cpu.IsBigEndian { + foldedSum = bits.ReverseBytes16(foldedSum) + } + return foldedSum } -func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { - sum := checksumNoFold(srcAddr, 0) - sum = checksumNoFold(dstAddr, sum) - sum = checksumNoFold([]byte{0, protocol}, sum) - tmp := make([]byte, 2) - binary.BigEndian.PutUint16(tmp, totalLen) - return checksumNoFold(tmp, sum) +func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + var sum uint32 + if cpu.IsBigEndian { + sum = uint32(totalLen) + uint32(protocol) + } else { + sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8 + } + sum, carry := addrPartialChecksum32(srcAddr, sum, 0) + sum, carry = addrPartialChecksum32(dstAddr, sum, carry) + + foldedSum := ipChecksumFold32(sum, carry) + if !cpu.IsBigEndian { + foldedSum = bits.ReverseBytes16(foldedSum) + } + return foldedSum +} + +func pseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + if strconv.IntSize < 64 { + return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen) + } + return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen) } diff --git a/vendor/github.com/tailscale/wireguard-go/tun/checksum_amd64.go b/vendor/github.com/tailscale/wireguard-go/tun/checksum_amd64.go new file mode 100644 index 0000000000..5e87693b6f --- /dev/null +++ b/vendor/github.com/tailscale/wireguard-go/tun/checksum_amd64.go @@ -0,0 +1,20 @@ +package tun + +import "golang.org/x/sys/cpu" + +// checksum computes an IP checksum starting with the provided initial value. 
+// The length of data should be at least 128 bytes for best performance. Smaller +// buffers will still compute a correct result. For best performance with +// smaller buffers, use shortChecksum(). +var checksum = checksumAMD64 + +func init() { + if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 { + checksum = checksumAVX2 + return + } + if cpu.X86.HasSSE2 { + checksum = checksumSSE2 + return + } +} diff --git a/vendor/github.com/tailscale/wireguard-go/tun/checksum_generated_amd64.go b/vendor/github.com/tailscale/wireguard-go/tun/checksum_generated_amd64.go new file mode 100644 index 0000000000..b4a29419b9 --- /dev/null +++ b/vendor/github.com/tailscale/wireguard-go/tun/checksum_generated_amd64.go @@ -0,0 +1,18 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. + +package tun + +// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2) +// +//go:noescape +func checksumAVX2(b []byte, initial uint16) uint16 + +// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2) +// +//go:noescape +func checksumSSE2(b []byte, initial uint16) uint16 + +// checksumAMD64 computes an IP checksum using amd64 baseline instructions +// +//go:noescape +func checksumAMD64(b []byte, initial uint16) uint16 diff --git a/vendor/github.com/tailscale/wireguard-go/tun/checksum_generated_amd64.s b/vendor/github.com/tailscale/wireguard-go/tun/checksum_generated_amd64.s new file mode 100644 index 0000000000..5f2e4c5254 --- /dev/null +++ b/vendor/github.com/tailscale/wireguard-go/tun/checksum_generated_amd64.s @@ -0,0 +1,851 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. 
+ +#include "textflag.h" + +DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" +DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff" +DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112 + +// func checksumAVX2(b []byte, initial uint16) uint16 +// Requires: AVX, AVX2, BMI2 +TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, 
otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. + MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 
16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + VPXOR Y0, Y0, Y0 + VPXOR Y1, Y1, Y1 + VPXOR Y2, Y2, Y2 + VPXOR Y3, Y3, Y3 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 16(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 32(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 48(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 64(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 80(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 96(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 112(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 128(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 144(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 160(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 176(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 192(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 208(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 224(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 240(DX), Y4 + VPADDD Y4, Y3, Y3 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + ADDQ $0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. 
+ LEAQ xmmLoadMasks<>+0(SB), CX + VMOVDQU -16(DX)(BX*1), X4 + VPAND -16(CX)(BX*8), X4, X4 + VPMOVZXWD X4, Y4 + VPADDD Y4, Y0, Y0 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + VPADDD Y1, Y0, Y0 + VPADDD Y2, Y0, Y0 + VPADDD Y3, Y0, Y0 + + // extract the YMM into a pair of XMM and sum them + VEXTRACTI128 $0x01, Y0, X1 + VPADDD X0, X1, X0 + + // extract the XMM into GP64 + VPEXTRQ $0x00, X0, CX + VPEXTRQ $0x01, X0, DX + + // no more AVX code, clear upper registers to avoid SSE slowdowns + VZEROUPPER + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + RORXQ $0x20, AX, CX + ADCL CX, AX + RORXL $0x10, AX, CX + ADCW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumSSE2(b []byte, initial uint16) uint16 +// Requires: SSE2 +TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with 
tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. + MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 
40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + PXOR X0, X0 + PXOR X1, X1 + PXOR X2, X2 + PXOR X3, X3 + PXOR X4, X4 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 16(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 32(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 48(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 64(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 80(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 96(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 112(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 128(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 144(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 160(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 176(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 192(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 208(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 
224(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 240(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + ADDQ $0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. + LEAQ xmmLoadMasks<>+0(SB), CX + MOVOU -16(DX)(BX*1), X5 + PAND -16(CX)(BX*8), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + PADDD X1, X0 + PADDD X2, X0 + PADDD X3, X0 + + // extract the XMM into GP64 + MOVQ X0, CX + PSRLDQ $0x08, X0 + MOVQ X0, DX + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumAMD64(b []byte, initial uint16) uint16 +TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = 
buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // Number of 256 byte iterations into loop counter + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + SHRQ $0x08, CX + JZ startCleanup + CLC + XORQ SI, SI + XORQ DI, DI + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + XORQ R12, R12 + +bigLoop: + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ 32(DX), DI + ADCQ 40(DX), DI + ADCQ 48(DX), DI + ADCQ 56(DX), DI + ADCQ $0x00, R8 + ADDQ 64(DX), R9 + ADCQ 72(DX), R9 + ADCQ 80(DX), R9 + ADCQ 88(DX), R9 + ADCQ $0x00, R10 + ADDQ 96(DX), R11 + ADCQ 104(DX), R11 + ADCQ 112(DX), R11 + ADCQ 120(DX), R11 + ADCQ $0x00, R12 + ADDQ 128(DX), AX + ADCQ 136(DX), AX + ADCQ 144(DX), AX + ADCQ 152(DX), AX + ADCQ $0x00, SI + ADDQ 160(DX), DI + ADCQ 168(DX), DI + ADCQ 176(DX), DI + ADCQ 184(DX), DI + ADCQ $0x00, R8 + ADDQ 192(DX), R9 + ADCQ 200(DX), R9 + ADCQ 208(DX), R9 + ADCQ 216(DX), R9 + ADCQ $0x00, R10 + ADDQ 224(DX), R11 + ADCQ 232(DX), R11 + ADCQ 240(DX), R11 + ADCQ 248(DX), R11 + ADCQ $0x00, R12 + ADDQ $0x00000100, DX + 
SUBQ $0x01, CX + JNZ bigLoop + ADDQ SI, AX + ADCQ DI, AX + ADCQ R8, AX + ADCQ R9, AX + ADCQ R10, AX + ADCQ R11, AX + ADCQ R12, AX + + // accumulate CF (twice, in case the first time overflows) + ADCQ $0x00, AX + ADCQ $0x00, AX + +startCleanup: + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. + MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ 
$0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET diff --git a/vendor/github.com/tailscale/wireguard-go/tun/checksum_generic.go b/vendor/github.com/tailscale/wireguard-go/tun/checksum_generic.go new file mode 100644 index 0000000000..d0bfb697ca --- /dev/null +++ b/vendor/github.com/tailscale/wireguard-go/tun/checksum_generic.go @@ -0,0 +1,15 @@ +// This file contains IP checksum algorithms that are not specific to any +// architecture and don't use hardware acceleration. + +//go:build !amd64 + +package tun + +import "strconv" + +func checksum(data []byte, initial uint16) uint16 { + if strconv.IntSize < 64 { + return checksumGeneric32(data, initial) + } + return checksumGeneric64(data, initial) +} diff --git a/vendor/github.com/tailscale/wireguard-go/tun/tcp_offload_linux.go b/vendor/github.com/tailscale/wireguard-go/tun/tcp_offload_linux.go index 67288237f8..b023bbd601 100644 --- a/vendor/github.com/tailscale/wireguard-go/tun/tcp_offload_linux.go +++ b/vendor/github.com/tailscale/wireguard-go/tun/tcp_offload_linux.go @@ -260,8 +260,8 @@ func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { addrSize = 16 } tcpTotalLen := uint16(len(pkt) - int(iphLen)) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) - return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0 + tcpCSum := pseudoHeaderChecksum(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) + return ^checksum(pkt[iphLen:], tcpCSum) == 0 } // coalesceResult represents the result of attempting to coalesce two TCP @@ -532,7 +532,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table 
*tcpGROTable, isV6 srcAddrAt := offset + addrOffset srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + psum := pseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) } else { hdr := virtioNetHdr{} @@ -643,8 +643,8 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs // TCP checksum tcpHLen := int(hdr.hdrLen - hdr.csumStart) tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) - tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold) + tcpCSum := pseudoHeaderChecksum(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) + tcpCSum = ^checksum(out[hdr.csumStart:totalLen], tcpCSum) binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) nextSegmentDataAt += int(hdr.gsoSize) @@ -658,6 +658,6 @@ func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { // checksum we compute. This is typically the pseudo-header checksum. 
initial := binary.BigEndian.Uint16(in[cSumAt:]) in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], initial)) return nil } diff --git a/vendor/go4.org/netipx/netipx.go b/vendor/go4.org/netipx/netipx.go index 53ce503570..08d57dbf71 100644 --- a/vendor/go4.org/netipx/netipx.go +++ b/vendor/go4.org/netipx/netipx.go @@ -545,3 +545,33 @@ func appendRangePrefixes(dst []netip.Prefix, makePrefix prefixMaker, a, b uint12 dst = appendRangePrefixes(dst, makePrefix, b.bitsClearedFrom(common+1), b) return dst } + +// CompareAddr returns -1 if a.Less(b), 1 if b.Less(0), else it +// returns 0. +func CompareAddr(a, b netip.Addr) int { + if a.Less(b) { + return -1 + } + if b.Less(a) { + return 1 + } + return 0 +} + +// ComparePrefix -1 if a.Addr().Less(b), 1 if +// b.Addr().Less(0), else if a and b have the same address, it +// compares their prefix bit length, returning -1, 0, or 1. +func ComparePrefix(a, b netip.Prefix) int { + aa, ba := a.Addr(), b.Addr() + if aa == ba { + ab, bb := a.Bits(), b.Bits() + if ab < bb { + return -1 + } + if bb < ab { + return 1 + } + return 0 + } + return CompareAddr(a.Addr(), b.Addr()) +} diff --git a/vendor/golang.org/x/exp/slices/cmp.go b/vendor/golang.org/x/exp/slices/cmp.go new file mode 100644 index 0000000000..fbf1934a06 --- /dev/null +++ b/vendor/golang.org/x/exp/slices/cmp.go @@ -0,0 +1,44 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package slices + +import "golang.org/x/exp/constraints" + +// min is a version of the predeclared function from the Go 1.21 release. +func min[T constraints.Ordered](a, b T) T { + if a < b || isNaN(a) { + return a + } + return b +} + +// max is a version of the predeclared function from the Go 1.21 release. 
+func max[T constraints.Ordered](a, b T) T { + if a > b || isNaN(a) { + return a + } + return b +} + +// cmpLess is a copy of cmp.Less from the Go 1.21 release. +func cmpLess[T constraints.Ordered](x, y T) bool { + return (isNaN(x) && !isNaN(y)) || x < y +} + +// cmpCompare is a copy of cmp.Compare from the Go 1.21 release. +func cmpCompare[T constraints.Ordered](x, y T) int { + xNaN := isNaN(x) + yNaN := isNaN(y) + if xNaN && yNaN { + return 0 + } + if xNaN || x < y { + return -1 + } + if yNaN || x > y { + return +1 + } + return 0 +} diff --git a/vendor/golang.org/x/exp/slices/slices.go b/vendor/golang.org/x/exp/slices/slices.go index 2540bd6825..5e8158bba8 100644 --- a/vendor/golang.org/x/exp/slices/slices.go +++ b/vendor/golang.org/x/exp/slices/slices.go @@ -3,23 +3,20 @@ // license that can be found in the LICENSE file. // Package slices defines various functions useful with slices of any type. -// Unless otherwise specified, these functions all apply to the elements -// of a slice at index 0 <= i < len(s). -// -// Note that the less function in IsSortedFunc, SortFunc, SortStableFunc requires a -// strict weak ordering (https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings), -// or the sorting may fail to sort correctly. A common case is when sorting slices of -// floating-point numbers containing NaN values. package slices -import "golang.org/x/exp/constraints" +import ( + "unsafe" + + "golang.org/x/exp/constraints" +) // Equal reports whether two slices are equal: the same length and all // elements equal. If the lengths are different, Equal returns false. // Otherwise, the elements are compared in increasing index order, and the // comparison stops at the first unequal pair. // Floating point NaNs are not considered equal. 
-func Equal[E comparable](s1, s2 []E) bool { +func Equal[S ~[]E, E comparable](s1, s2 S) bool { if len(s1) != len(s2) { return false } @@ -31,12 +28,12 @@ func Equal[E comparable](s1, s2 []E) bool { return true } -// EqualFunc reports whether two slices are equal using a comparison +// EqualFunc reports whether two slices are equal using an equality // function on each pair of elements. If the lengths are different, // EqualFunc returns false. Otherwise, the elements are compared in // increasing index order, and the comparison stops at the first index // for which eq returns false. -func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool { +func EqualFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, eq func(E1, E2) bool) bool { if len(s1) != len(s2) { return false } @@ -49,45 +46,37 @@ func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool { return true } -// Compare compares the elements of s1 and s2. -// The elements are compared sequentially, starting at index 0, +// Compare compares the elements of s1 and s2, using [cmp.Compare] on each pair +// of elements. The elements are compared sequentially, starting at index 0, // until one element is not equal to the other. // The result of comparing the first non-matching elements is returned. // If both slices are equal until one of them ends, the shorter slice is // considered less than the longer one. // The result is 0 if s1 == s2, -1 if s1 < s2, and +1 if s1 > s2. -// Comparisons involving floating point NaNs are ignored. 
-func Compare[E constraints.Ordered](s1, s2 []E) int { - s2len := len(s2) +func Compare[S ~[]E, E constraints.Ordered](s1, s2 S) int { for i, v1 := range s1 { - if i >= s2len { + if i >= len(s2) { return +1 } v2 := s2[i] - switch { - case v1 < v2: - return -1 - case v1 > v2: - return +1 + if c := cmpCompare(v1, v2); c != 0 { + return c } } - if len(s1) < s2len { + if len(s1) < len(s2) { return -1 } return 0 } -// CompareFunc is like Compare but uses a comparison function -// on each pair of elements. The elements are compared in increasing -// index order, and the comparisons stop after the first time cmp -// returns non-zero. +// CompareFunc is like [Compare] but uses a custom comparison function on each +// pair of elements. // The result is the first non-zero result of cmp; if cmp always // returns 0 the result is 0 if len(s1) == len(s2), -1 if len(s1) < len(s2), // and +1 if len(s1) > len(s2). -func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { - s2len := len(s2) +func CompareFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, cmp func(E1, E2) int) int { for i, v1 := range s1 { - if i >= s2len { + if i >= len(s2) { return +1 } v2 := s2[i] @@ -95,7 +84,7 @@ func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { return c } } - if len(s1) < s2len { + if len(s1) < len(s2) { return -1 } return 0 @@ -103,7 +92,7 @@ func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { // Index returns the index of the first occurrence of v in s, // or -1 if not present. -func Index[E comparable](s []E, v E) int { +func Index[S ~[]E, E comparable](s S, v E) int { for i := range s { if v == s[i] { return i @@ -114,7 +103,7 @@ func Index[E comparable](s []E, v E) int { // IndexFunc returns the first index i satisfying f(s[i]), // or -1 if none do. 
-func IndexFunc[E any](s []E, f func(E) bool) int { +func IndexFunc[S ~[]E, E any](s S, f func(E) bool) int { for i := range s { if f(s[i]) { return i @@ -124,39 +113,104 @@ func IndexFunc[E any](s []E, f func(E) bool) int { } // Contains reports whether v is present in s. -func Contains[E comparable](s []E, v E) bool { +func Contains[S ~[]E, E comparable](s S, v E) bool { return Index(s, v) >= 0 } // ContainsFunc reports whether at least one // element e of s satisfies f(e). -func ContainsFunc[E any](s []E, f func(E) bool) bool { +func ContainsFunc[S ~[]E, E any](s S, f func(E) bool) bool { return IndexFunc(s, f) >= 0 } // Insert inserts the values v... into s at index i, // returning the modified slice. -// In the returned slice r, r[i] == v[0]. +// The elements at s[i:] are shifted up to make room. +// In the returned slice r, r[i] == v[0], +// and r[i+len(v)] == value originally at r[i]. // Insert panics if i is out of range. // This function is O(len(s) + len(v)). func Insert[S ~[]E, E any](s S, i int, v ...E) S { - tot := len(s) + len(v) - if tot <= cap(s) { - s2 := s[:tot] - copy(s2[i+len(v):], s[i:]) + m := len(v) + if m == 0 { + return s + } + n := len(s) + if i == n { + return append(s, v...) + } + if n+m > cap(s) { + // Use append rather than make so that we bump the size of + // the slice up to the next storage class. + // This is what Grow does but we don't call Grow because + // that might copy the values twice. + s2 := append(s[:i], make(S, n+m-i)...) copy(s2[i:], v) + copy(s2[i+m:], s[i:]) return s2 } - s2 := make(S, tot) - copy(s2, s[:i]) - copy(s2[i:], v) - copy(s2[i+len(v):], s[i:]) - return s2 + s = s[:n+m] + + // before: + // s: aaaaaaaabbbbccccccccdddd + // ^ ^ ^ ^ + // i i+m n n+m + // after: + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // + // a are the values that don't move in s. + // v are the values copied in from v. + // b and c are the values from s that are shifted up in index. 
+ // d are the values that get overwritten, never to be seen again. + + if !overlaps(v, s[i+m:]) { + // Easy case - v does not overlap either the c or d regions. + // (It might be in some of a or b, or elsewhere entirely.) + // The data we copy up doesn't write to v at all, so just do it. + + copy(s[i+m:], s[i:]) + + // Now we have + // s: aaaaaaaabbbbbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // Note the b values are duplicated. + + copy(s[i:], v) + + // Now we have + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // That's the result we want. + return s + } + + // The hard case - v overlaps c or d. We can't just shift up + // the data because we'd move or clobber the values we're trying + // to insert. + // So instead, write v on top of d, then rotate. + copy(s[n:], v) + + // Now we have + // s: aaaaaaaabbbbccccccccvvvv + // ^ ^ ^ ^ + // i i+m n n+m + + rotateRight(s[i:], m) + + // Now we have + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // That's the result we want. + return s } // Delete removes the elements s[i:j] from s, returning the modified slice. // Delete panics if s[i:j] is not a valid slice of s. -// Delete modifies the contents of the slice s; it does not create a new slice. // Delete is O(len(s)-j), so if many items must be deleted, it is better to // make a single call deleting them all together than to delete one at a time. // Delete might not modify the elements s[len(s)-(j-i):len(s)]. If those @@ -168,22 +222,113 @@ func Delete[S ~[]E, E any](s S, i, j int) S { return append(s[:i], s[j:]...) } +// DeleteFunc removes any elements from s for which del returns true, +// returning the modified slice. +// When DeleteFunc removes m elements, it might not modify the elements +// s[len(s)-m:len(s)]. If those elements contain pointers you might consider +// zeroing those elements so that objects they reference can be garbage +// collected. 
+func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { + i := IndexFunc(s, del) + if i == -1 { + return s + } + // Don't start copying elements until we find one to delete. + for j := i + 1; j < len(s); j++ { + if v := s[j]; !del(v) { + s[i] = v + i++ + } + } + return s[:i] +} + // Replace replaces the elements s[i:j] by the given v, and returns the // modified slice. Replace panics if s[i:j] is not a valid slice of s. func Replace[S ~[]E, E any](s S, i, j int, v ...E) S { _ = s[i:j] // verify that i:j is a valid subslice + + if i == j { + return Insert(s, i, v...) + } + if j == len(s) { + return append(s[:i], v...) + } + tot := len(s[:i]) + len(v) + len(s[j:]) - if tot <= cap(s) { - s2 := s[:tot] - copy(s2[i+len(v):], s[j:]) + if tot > cap(s) { + // Too big to fit, allocate and copy over. + s2 := append(s[:i], make(S, tot-i)...) // See Insert copy(s2[i:], v) + copy(s2[i+len(v):], s[j:]) return s2 } - s2 := make(S, tot) - copy(s2, s[:i]) - copy(s2[i:], v) - copy(s2[i+len(v):], s[j:]) - return s2 + + r := s[:tot] + + if i+len(v) <= j { + // Easy, as v fits in the deleted portion. + copy(r[i:], v) + if i+len(v) != j { + copy(r[i+len(v):], s[j:]) + } + return r + } + + // We are expanding (v is bigger than j-i). + // The situation is something like this: + // (example has i=4,j=8,len(s)=16,len(v)=6) + // s: aaaaxxxxbbbbbbbbyy + // ^ ^ ^ ^ + // i j len(s) tot + // a: prefix of s + // x: deleted range + // b: more of s + // y: area to expand into + + if !overlaps(r[i+len(v):], v) { + // Easy, as v is not clobbered by the first copy. + copy(r[i+len(v):], s[j:]) + copy(r[i:], v) + return r + } + + // This is a situation where we don't have a single place to which + // we can copy v. Parts of it need to go to two different places. + // We want to copy the prefix of v into y and the suffix into x, then + // rotate |y| spots to the right. 
+ // + // v[2:] v[:2] + // | | + // s: aaaavvvvbbbbbbbbvv + // ^ ^ ^ ^ + // i j len(s) tot + // + // If either of those two destinations don't alias v, then we're good. + y := len(v) - (j - i) // length of y portion + + if !overlaps(r[i:j], v) { + copy(r[i:j], v[y:]) + copy(r[len(s):], v[:y]) + rotateRight(r[i:], y) + return r + } + if !overlaps(r[len(s):], v) { + copy(r[len(s):], v[:y]) + copy(r[i:j], v[y:]) + rotateRight(r[i:], y) + return r + } + + // Now we know that v overlaps both x and y. + // That means that the entirety of b is *inside* v. + // So we don't need to preserve b at all; instead we + // can copy v first, then copy the b part of v out of + // v to the right destination. + k := startIdx(v, s[j:]) + copy(r[i:], v) + copy(r[i+len(v):], r[i+k:]) + return r } // Clone returns a copy of the slice. @@ -198,7 +343,8 @@ func Clone[S ~[]E, E any](s S) S { // Compact replaces consecutive runs of equal elements with a single copy. // This is like the uniq command found on Unix. -// Compact modifies the contents of the slice s; it does not create a new slice. +// Compact modifies the contents of the slice s and returns the modified slice, +// which may have a smaller length. // When Compact discards m elements in total, it might not modify the elements // s[len(s)-m:len(s)]. If those elements contain pointers you might consider // zeroing those elements so that objects they reference can be garbage collected. @@ -218,7 +364,8 @@ func Compact[S ~[]E, E comparable](s S) S { return s[:i] } -// CompactFunc is like Compact but uses a comparison function. +// CompactFunc is like [Compact] but uses an equality function to compare elements. +// For runs of elements that compare equal, CompactFunc keeps the first one. 
func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S { if len(s) < 2 { return s @@ -256,3 +403,97 @@ func Grow[S ~[]E, E any](s S, n int) S { func Clip[S ~[]E, E any](s S) S { return s[:len(s):len(s)] } + +// Rotation algorithm explanation: +// +// rotate left by 2 +// start with +// 0123456789 +// split up like this +// 01 234567 89 +// swap first 2 and last 2 +// 89 234567 01 +// join first parts +// 89234567 01 +// recursively rotate first left part by 2 +// 23456789 01 +// join at the end +// 2345678901 +// +// rotate left by 8 +// start with +// 0123456789 +// split up like this +// 01 234567 89 +// swap first 2 and last 2 +// 89 234567 01 +// join last parts +// 89 23456701 +// recursively rotate second part left by 6 +// 89 01234567 +// join at the end +// 8901234567 + +// TODO: There are other rotate algorithms. +// This algorithm has the desirable property that it moves each element exactly twice. +// The triple-reverse algorithm is simpler and more cache friendly, but takes more writes. +// The follow-cycles algorithm can be 1-write but it is not very cache friendly. + +// rotateLeft rotates b left by n spaces. +// s_final[i] = s_orig[i+r], wrapping around. +func rotateLeft[E any](s []E, r int) { + for r != 0 && r != len(s) { + if r*2 <= len(s) { + swap(s[:r], s[len(s)-r:]) + s = s[:len(s)-r] + } else { + swap(s[:len(s)-r], s[r:]) + s, r = s[len(s)-r:], r*2-len(s) + } + } +} +func rotateRight[E any](s []E, r int) { + rotateLeft(s, len(s)-r) +} + +// swap swaps the contents of x and y. x and y must be equal length and disjoint. +func swap[E any](x, y []E) { + for i := 0; i < len(x); i++ { + x[i], y[i] = y[i], x[i] + } +} + +// overlaps reports whether the memory ranges a[0:len(a)] and b[0:len(b)] overlap. +func overlaps[E any](a, b []E) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + elemSize := unsafe.Sizeof(a[0]) + if elemSize == 0 { + return false + } + // TODO: use a runtime/unsafe facility once one becomes available. 
See issue 12445. + // Also see crypto/internal/alias/alias.go:AnyOverlap + return uintptr(unsafe.Pointer(&a[0])) <= uintptr(unsafe.Pointer(&b[len(b)-1]))+(elemSize-1) && + uintptr(unsafe.Pointer(&b[0])) <= uintptr(unsafe.Pointer(&a[len(a)-1]))+(elemSize-1) +} + +// startIdx returns the index in haystack where the needle starts. +// prerequisite: the needle must be aliased entirely inside the haystack. +func startIdx[E any](haystack, needle []E) int { + p := &needle[0] + for i := range haystack { + if p == &haystack[i] { + return i + } + } + // TODO: what if the overlap is by a non-integral number of Es? + panic("needle not found") +} + +// Reverse reverses the elements of the slice in place. +func Reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +} diff --git a/vendor/golang.org/x/exp/slices/sort.go b/vendor/golang.org/x/exp/slices/sort.go index 231b6448ac..b67897f76b 100644 --- a/vendor/golang.org/x/exp/slices/sort.go +++ b/vendor/golang.org/x/exp/slices/sort.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:generate go run $GOROOT/src/sort/gen_sort_variants.go -exp + package slices import ( @@ -11,57 +13,116 @@ import ( ) // Sort sorts a slice of any ordered type in ascending order. -// Sort may fail to sort correctly when sorting slices of floating-point -// numbers containing Not-a-number (NaN) values. -// Use slices.SortFunc(x, func(a, b float64) bool {return a < b || (math.IsNaN(a) && !math.IsNaN(b))}) -// instead if the input may contain NaNs. -func Sort[E constraints.Ordered](x []E) { +// When sorting floating-point numbers, NaNs are ordered before other values. +func Sort[S ~[]E, E constraints.Ordered](x S) { n := len(x) pdqsortOrdered(x, 0, n, bits.Len(uint(n))) } -// SortFunc sorts the slice x in ascending order as determined by the less function. -// This sort is not guaranteed to be stable. 
+// SortFunc sorts the slice x in ascending order as determined by the cmp +// function. This sort is not guaranteed to be stable. +// cmp(a, b) should return a negative number when a < b, a positive number when +// a > b and zero when a == b. // -// SortFunc requires that less is a strict weak ordering. +// SortFunc requires that cmp is a strict weak ordering. // See https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings. -func SortFunc[E any](x []E, less func(a, b E) bool) { +func SortFunc[S ~[]E, E any](x S, cmp func(a, b E) int) { n := len(x) - pdqsortLessFunc(x, 0, n, bits.Len(uint(n)), less) + pdqsortCmpFunc(x, 0, n, bits.Len(uint(n)), cmp) } // SortStableFunc sorts the slice x while keeping the original order of equal -// elements, using less to compare elements. -func SortStableFunc[E any](x []E, less func(a, b E) bool) { - stableLessFunc(x, len(x), less) +// elements, using cmp to compare elements in the same way as [SortFunc]. +func SortStableFunc[S ~[]E, E any](x S, cmp func(a, b E) int) { + stableCmpFunc(x, len(x), cmp) } // IsSorted reports whether x is sorted in ascending order. -func IsSorted[E constraints.Ordered](x []E) bool { +func IsSorted[S ~[]E, E constraints.Ordered](x S) bool { for i := len(x) - 1; i > 0; i-- { - if x[i] < x[i-1] { + if cmpLess(x[i], x[i-1]) { return false } } return true } -// IsSortedFunc reports whether x is sorted in ascending order, with less as the -// comparison function. -func IsSortedFunc[E any](x []E, less func(a, b E) bool) bool { +// IsSortedFunc reports whether x is sorted in ascending order, with cmp as the +// comparison function as defined by [SortFunc]. +func IsSortedFunc[S ~[]E, E any](x S, cmp func(a, b E) int) bool { for i := len(x) - 1; i > 0; i-- { - if less(x[i], x[i-1]) { + if cmp(x[i], x[i-1]) < 0 { return false } } return true } +// Min returns the minimal value in x. It panics if x is empty. 
+// For floating-point numbers, Min propagates NaNs (any NaN value in x +// forces the output to be NaN). +func Min[S ~[]E, E constraints.Ordered](x S) E { + if len(x) < 1 { + panic("slices.Min: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + m = min(m, x[i]) + } + return m +} + +// MinFunc returns the minimal value in x, using cmp to compare elements. +// It panics if x is empty. If there is more than one minimal element +// according to the cmp function, MinFunc returns the first one. +func MinFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { + if len(x) < 1 { + panic("slices.MinFunc: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + if cmp(x[i], m) < 0 { + m = x[i] + } + } + return m +} + +// Max returns the maximal value in x. It panics if x is empty. +// For floating-point E, Max propagates NaNs (any NaN value in x +// forces the output to be NaN). +func Max[S ~[]E, E constraints.Ordered](x S) E { + if len(x) < 1 { + panic("slices.Max: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + m = max(m, x[i]) + } + return m +} + +// MaxFunc returns the maximal value in x, using cmp to compare elements. +// It panics if x is empty. If there is more than one maximal element +// according to the cmp function, MaxFunc returns the first one. +func MaxFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { + if len(x) < 1 { + panic("slices.MaxFunc: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + if cmp(x[i], m) > 0 { + m = x[i] + } + } + return m +} + // BinarySearch searches for target in a sorted slice and returns the position // where target is found, or the position where target would appear in the // sort order; it also returns a bool saying whether the target is really found // in the slice. The slice must be sorted in increasing order. 
-func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) { +func BinarySearch[S ~[]E, E constraints.Ordered](x S, target E) (int, bool) { // Inlining is faster than calling BinarySearchFunc with a lambda. n := len(x) // Define x[-1] < target and x[n] >= target. @@ -70,24 +131,24 @@ func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) { for i < j { h := int(uint(i+j) >> 1) // avoid overflow when computing h // i ≤ h < j - if x[h] < target { + if cmpLess(x[h], target) { i = h + 1 // preserves x[i-1] < target } else { j = h // preserves x[j] >= target } } // i == j, x[i-1] < target, and x[j] (= x[i]) >= target => answer is i. - return i, i < n && x[i] == target + return i, i < n && (x[i] == target || (isNaN(x[i]) && isNaN(target))) } -// BinarySearchFunc works like BinarySearch, but uses a custom comparison +// BinarySearchFunc works like [BinarySearch], but uses a custom comparison // function. The slice must be sorted in increasing order, where "increasing" // is defined by cmp. cmp should return 0 if the slice element matches // the target, a negative number if the slice element precedes the target, // or a positive number if the slice element follows the target. // cmp must implement the same ordering as the slice, such that if // cmp(a, t) < 0 and cmp(b, t) >= 0, then a must precede b in the slice. -func BinarySearchFunc[E, T any](x []E, target T, cmp func(E, T) int) (int, bool) { +func BinarySearchFunc[S ~[]E, E, T any](x S, target T, cmp func(E, T) int) (int, bool) { n := len(x) // Define cmp(x[-1], target) < 0 and cmp(x[n], target) >= 0 . // Invariant: cmp(x[i - 1], target) < 0, cmp(x[j], target) >= 0. @@ -126,3 +187,9 @@ func (r *xorshift) Next() uint64 { func nextPowerOfTwo(length int) uint { return 1 << bits.Len(uint(length)) } + +// isNaN reports whether x is a NaN without requiring the math package. +// This will always return false if T is not floating-point. 
+func isNaN[T constraints.Ordered](x T) bool { + return x != x +} diff --git a/vendor/golang.org/x/exp/slices/zsortfunc.go b/vendor/golang.org/x/exp/slices/zsortanyfunc.go similarity index 64% rename from vendor/golang.org/x/exp/slices/zsortfunc.go rename to vendor/golang.org/x/exp/slices/zsortanyfunc.go index 2a632476c5..06f2c7a248 100644 --- a/vendor/golang.org/x/exp/slices/zsortfunc.go +++ b/vendor/golang.org/x/exp/slices/zsortanyfunc.go @@ -6,28 +6,28 @@ package slices -// insertionSortLessFunc sorts data[a:b] using insertion sort. -func insertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +// insertionSortCmpFunc sorts data[a:b] using insertion sort. +func insertionSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { for i := a + 1; i < b; i++ { - for j := i; j > a && less(data[j], data[j-1]); j-- { + for j := i; j > a && (cmp(data[j], data[j-1]) < 0); j-- { data[j], data[j-1] = data[j-1], data[j] } } } -// siftDownLessFunc implements the heap property on data[lo:hi]. +// siftDownCmpFunc implements the heap property on data[lo:hi]. // first is an offset into the array where the root of the heap lies. 
-func siftDownLessFunc[E any](data []E, lo, hi, first int, less func(a, b E) bool) { +func siftDownCmpFunc[E any](data []E, lo, hi, first int, cmp func(a, b E) int) { root := lo for { child := 2*root + 1 if child >= hi { break } - if child+1 < hi && less(data[first+child], data[first+child+1]) { + if child+1 < hi && (cmp(data[first+child], data[first+child+1]) < 0) { child++ } - if !less(data[first+root], data[first+child]) { + if !(cmp(data[first+root], data[first+child]) < 0) { return } data[first+root], data[first+child] = data[first+child], data[first+root] @@ -35,30 +35,30 @@ func siftDownLessFunc[E any](data []E, lo, hi, first int, less func(a, b E) bool } } -func heapSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func heapSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { first := a lo := 0 hi := b - a // Build heap with greatest element at top. for i := (hi - 1) / 2; i >= 0; i-- { - siftDownLessFunc(data, i, hi, first, less) + siftDownCmpFunc(data, i, hi, first, cmp) } // Pop elements, largest first, into end of data. for i := hi - 1; i >= 0; i-- { data[first], data[first+i] = data[first+i], data[first] - siftDownLessFunc(data, lo, i, first, less) + siftDownCmpFunc(data, lo, i, first, cmp) } } -// pdqsortLessFunc sorts data[a:b]. +// pdqsortCmpFunc sorts data[a:b]. // The algorithm based on pattern-defeating quicksort(pdqsort), but without the optimizations from BlockQuicksort. // pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf // C++ implementation: https://github.com/orlp/pdqsort // Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/ // limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort. 
-func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { +func pdqsortCmpFunc[E any](data []E, a, b, limit int, cmp func(a, b E) int) { const maxInsertion = 12 var ( @@ -70,25 +70,25 @@ func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { length := b - a if length <= maxInsertion { - insertionSortLessFunc(data, a, b, less) + insertionSortCmpFunc(data, a, b, cmp) return } // Fall back to heapsort if too many bad choices were made. if limit == 0 { - heapSortLessFunc(data, a, b, less) + heapSortCmpFunc(data, a, b, cmp) return } // If the last partitioning was imbalanced, we need to breaking patterns. if !wasBalanced { - breakPatternsLessFunc(data, a, b, less) + breakPatternsCmpFunc(data, a, b, cmp) limit-- } - pivot, hint := choosePivotLessFunc(data, a, b, less) + pivot, hint := choosePivotCmpFunc(data, a, b, cmp) if hint == decreasingHint { - reverseRangeLessFunc(data, a, b, less) + reverseRangeCmpFunc(data, a, b, cmp) // The chosen pivot was pivot-a elements after the start of the array. // After reversing it is pivot-a elements before the end of the array. // The idea came from Rust's implementation. @@ -98,48 +98,48 @@ func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { // The slice is likely already sorted. if wasBalanced && wasPartitioned && hint == increasingHint { - if partialInsertionSortLessFunc(data, a, b, less) { + if partialInsertionSortCmpFunc(data, a, b, cmp) { return } } // Probably the slice contains many duplicate elements, partition the slice into // elements equal to and elements greater than the pivot. 
- if a > 0 && !less(data[a-1], data[pivot]) { - mid := partitionEqualLessFunc(data, a, b, pivot, less) + if a > 0 && !(cmp(data[a-1], data[pivot]) < 0) { + mid := partitionEqualCmpFunc(data, a, b, pivot, cmp) a = mid continue } - mid, alreadyPartitioned := partitionLessFunc(data, a, b, pivot, less) + mid, alreadyPartitioned := partitionCmpFunc(data, a, b, pivot, cmp) wasPartitioned = alreadyPartitioned leftLen, rightLen := mid-a, b-mid balanceThreshold := length / 8 if leftLen < rightLen { wasBalanced = leftLen >= balanceThreshold - pdqsortLessFunc(data, a, mid, limit, less) + pdqsortCmpFunc(data, a, mid, limit, cmp) a = mid + 1 } else { wasBalanced = rightLen >= balanceThreshold - pdqsortLessFunc(data, mid+1, b, limit, less) + pdqsortCmpFunc(data, mid+1, b, limit, cmp) b = mid } } } -// partitionLessFunc does one quicksort partition. +// partitionCmpFunc does one quicksort partition. // Let p = data[pivot] // Moves elements in data[a:b] around, so that data[i]

=p for inewpivot. // On return, data[newpivot] = p -func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) (newpivot int, alreadyPartitioned bool) { +func partitionCmpFunc[E any](data []E, a, b, pivot int, cmp func(a, b E) int) (newpivot int, alreadyPartitioned bool) { data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned - for i <= j && less(data[i], data[a]) { + for i <= j && (cmp(data[i], data[a]) < 0) { i++ } - for i <= j && !less(data[j], data[a]) { + for i <= j && !(cmp(data[j], data[a]) < 0) { j-- } if i > j { @@ -151,10 +151,10 @@ func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) j-- for { - for i <= j && less(data[i], data[a]) { + for i <= j && (cmp(data[i], data[a]) < 0) { i++ } - for i <= j && !less(data[j], data[a]) { + for i <= j && !(cmp(data[j], data[a]) < 0) { j-- } if i > j { @@ -168,17 +168,17 @@ func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) return j, false } -// partitionEqualLessFunc partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot]. +// partitionEqualCmpFunc partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot]. // It assumed that data[a:b] does not contain elements smaller than the data[pivot]. 
-func partitionEqualLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) (newpivot int) { +func partitionEqualCmpFunc[E any](data []E, a, b, pivot int, cmp func(a, b E) int) (newpivot int) { data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned for { - for i <= j && !less(data[a], data[i]) { + for i <= j && !(cmp(data[a], data[i]) < 0) { i++ } - for i <= j && less(data[a], data[j]) { + for i <= j && (cmp(data[a], data[j]) < 0) { j-- } if i > j { @@ -191,15 +191,15 @@ func partitionEqualLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) return i } -// partialInsertionSortLessFunc partially sorts a slice, returns true if the slice is sorted at the end. -func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) bool { +// partialInsertionSortCmpFunc partially sorts a slice, returns true if the slice is sorted at the end. +func partialInsertionSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) bool { const ( maxSteps = 5 // maximum number of adjacent out-of-order pairs that will get shifted shortestShifting = 50 // don't shift any elements on short arrays ) i := a + 1 for j := 0; j < maxSteps; j++ { - for i < b && !less(data[i], data[i-1]) { + for i < b && !(cmp(data[i], data[i-1]) < 0) { i++ } @@ -216,7 +216,7 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b // Shift the smaller one to the left. if i-a >= 2 { for j := i - 1; j >= 1; j-- { - if !less(data[j], data[j-1]) { + if !(cmp(data[j], data[j-1]) < 0) { break } data[j], data[j-1] = data[j-1], data[j] @@ -225,7 +225,7 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b // Shift the greater one to the right. 
if b-i >= 2 { for j := i + 1; j < b; j++ { - if !less(data[j], data[j-1]) { + if !(cmp(data[j], data[j-1]) < 0) { break } data[j], data[j-1] = data[j-1], data[j] @@ -235,9 +235,9 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b return false } -// breakPatternsLessFunc scatters some elements around in an attempt to break some patterns +// breakPatternsCmpFunc scatters some elements around in an attempt to break some patterns // that might cause imbalanced partitions in quicksort. -func breakPatternsLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func breakPatternsCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { length := b - a if length >= 8 { random := xorshift(length) @@ -253,12 +253,12 @@ func breakPatternsLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { } } -// choosePivotLessFunc chooses a pivot in data[a:b]. +// choosePivotCmpFunc chooses a pivot in data[a:b]. // // [0,8): chooses a static pivot. // [8,shortestNinther): uses the simple median-of-three method. // [shortestNinther,∞): uses the Tukey ninther method. -func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (pivot int, hint sortedHint) { +func choosePivotCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) (pivot int, hint sortedHint) { const ( shortestNinther = 50 maxSwaps = 4 * 3 @@ -276,12 +276,12 @@ func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (piv if l >= 8 { if l >= shortestNinther { // Tukey ninther method, the idea came from Rust's implementation. - i = medianAdjacentLessFunc(data, i, &swaps, less) - j = medianAdjacentLessFunc(data, j, &swaps, less) - k = medianAdjacentLessFunc(data, k, &swaps, less) + i = medianAdjacentCmpFunc(data, i, &swaps, cmp) + j = medianAdjacentCmpFunc(data, j, &swaps, cmp) + k = medianAdjacentCmpFunc(data, k, &swaps, cmp) } // Find the median among i, j, k and stores it into j. 
- j = medianLessFunc(data, i, j, k, &swaps, less) + j = medianCmpFunc(data, i, j, k, &swaps, cmp) } switch swaps { @@ -294,29 +294,29 @@ func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (piv } } -// order2LessFunc returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. -func order2LessFunc[E any](data []E, a, b int, swaps *int, less func(a, b E) bool) (int, int) { - if less(data[b], data[a]) { +// order2CmpFunc returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. +func order2CmpFunc[E any](data []E, a, b int, swaps *int, cmp func(a, b E) int) (int, int) { + if cmp(data[b], data[a]) < 0 { *swaps++ return b, a } return a, b } -// medianLessFunc returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c. -func medianLessFunc[E any](data []E, a, b, c int, swaps *int, less func(a, b E) bool) int { - a, b = order2LessFunc(data, a, b, swaps, less) - b, c = order2LessFunc(data, b, c, swaps, less) - a, b = order2LessFunc(data, a, b, swaps, less) +// medianCmpFunc returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c. +func medianCmpFunc[E any](data []E, a, b, c int, swaps *int, cmp func(a, b E) int) int { + a, b = order2CmpFunc(data, a, b, swaps, cmp) + b, c = order2CmpFunc(data, b, c, swaps, cmp) + a, b = order2CmpFunc(data, a, b, swaps, cmp) return b } -// medianAdjacentLessFunc finds the median of data[a - 1], data[a], data[a + 1] and stores the index into a. -func medianAdjacentLessFunc[E any](data []E, a int, swaps *int, less func(a, b E) bool) int { - return medianLessFunc(data, a-1, a, a+1, swaps, less) +// medianAdjacentCmpFunc finds the median of data[a - 1], data[a], data[a + 1] and stores the index into a. 
+func medianAdjacentCmpFunc[E any](data []E, a int, swaps *int, cmp func(a, b E) int) int { + return medianCmpFunc(data, a-1, a, a+1, swaps, cmp) } -func reverseRangeLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func reverseRangeCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { i := a j := b - 1 for i < j { @@ -326,37 +326,37 @@ func reverseRangeLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { } } -func swapRangeLessFunc[E any](data []E, a, b, n int, less func(a, b E) bool) { +func swapRangeCmpFunc[E any](data []E, a, b, n int, cmp func(a, b E) int) { for i := 0; i < n; i++ { data[a+i], data[b+i] = data[b+i], data[a+i] } } -func stableLessFunc[E any](data []E, n int, less func(a, b E) bool) { +func stableCmpFunc[E any](data []E, n int, cmp func(a, b E) int) { blockSize := 20 // must be > 0 a, b := 0, blockSize for b <= n { - insertionSortLessFunc(data, a, b, less) + insertionSortCmpFunc(data, a, b, cmp) a = b b += blockSize } - insertionSortLessFunc(data, a, n, less) + insertionSortCmpFunc(data, a, n, cmp) for blockSize < n { a, b = 0, 2*blockSize for b <= n { - symMergeLessFunc(data, a, a+blockSize, b, less) + symMergeCmpFunc(data, a, a+blockSize, b, cmp) a = b b += 2 * blockSize } if m := a + blockSize; m < n { - symMergeLessFunc(data, a, m, n, less) + symMergeCmpFunc(data, a, m, n, cmp) } blockSize *= 2 } } -// symMergeLessFunc merges the two sorted subsequences data[a:m] and data[m:b] using +// symMergeCmpFunc merges the two sorted subsequences data[a:m] and data[m:b] using // the SymMerge algorithm from Pok-Son Kim and Arne Kutzner, "Stable Minimum // Storage Merging by Symmetric Comparisons", in Susanne Albers and Tomasz // Radzik, editors, Algorithms - ESA 2004, volume 3221 of Lecture Notes in @@ -375,7 +375,7 @@ func stableLessFunc[E any](data []E, n int, less func(a, b E) bool) { // symMerge assumes non-degenerate arguments: a < m && m < b. 
// Having the caller check this condition eliminates many leaf recursion calls, // which improves performance. -func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { +func symMergeCmpFunc[E any](data []E, a, m, b int, cmp func(a, b E) int) { // Avoid unnecessary recursions of symMerge // by direct insertion of data[a] into data[m:b] // if data[a:m] only contains one element. @@ -387,7 +387,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { j := b for i < j { h := int(uint(i+j) >> 1) - if less(data[h], data[a]) { + if cmp(data[h], data[a]) < 0 { i = h + 1 } else { j = h @@ -411,7 +411,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { j := m for i < j { h := int(uint(i+j) >> 1) - if !less(data[m], data[h]) { + if !(cmp(data[m], data[h]) < 0) { i = h + 1 } else { j = h @@ -438,7 +438,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { for start < r { c := int(uint(start+r) >> 1) - if !less(data[p-c], data[c]) { + if !(cmp(data[p-c], data[c]) < 0) { start = c + 1 } else { r = c @@ -447,33 +447,33 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { end := n - start if start < m && m < end { - rotateLessFunc(data, start, m, end, less) + rotateCmpFunc(data, start, m, end, cmp) } if a < start && start < mid { - symMergeLessFunc(data, a, start, mid, less) + symMergeCmpFunc(data, a, start, mid, cmp) } if mid < end && end < b { - symMergeLessFunc(data, mid, end, b, less) + symMergeCmpFunc(data, mid, end, b, cmp) } } -// rotateLessFunc rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data: +// rotateCmpFunc rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data: // Data of the form 'x u v y' is changed to 'x v u y'. // rotate performs at most b-a many calls to data.Swap, // and it assumes non-degenerate arguments: a < m && m < b. 
-func rotateLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { +func rotateCmpFunc[E any](data []E, a, m, b int, cmp func(a, b E) int) { i := m - a j := b - m for i != j { if i > j { - swapRangeLessFunc(data, m-i, m, j, less) + swapRangeCmpFunc(data, m-i, m, j, cmp) i -= j } else { - swapRangeLessFunc(data, m-i, m+j-i, i, less) + swapRangeCmpFunc(data, m-i, m+j-i, i, cmp) j -= i } } // i == j - swapRangeLessFunc(data, m-i, m, i, less) + swapRangeCmpFunc(data, m-i, m, i, cmp) } diff --git a/vendor/golang.org/x/exp/slices/zsortordered.go b/vendor/golang.org/x/exp/slices/zsortordered.go index efaa1c8b71..99b47c3986 100644 --- a/vendor/golang.org/x/exp/slices/zsortordered.go +++ b/vendor/golang.org/x/exp/slices/zsortordered.go @@ -11,7 +11,7 @@ import "golang.org/x/exp/constraints" // insertionSortOrdered sorts data[a:b] using insertion sort. func insertionSortOrdered[E constraints.Ordered](data []E, a, b int) { for i := a + 1; i < b; i++ { - for j := i; j > a && (data[j] < data[j-1]); j-- { + for j := i; j > a && cmpLess(data[j], data[j-1]); j-- { data[j], data[j-1] = data[j-1], data[j] } } @@ -26,10 +26,10 @@ func siftDownOrdered[E constraints.Ordered](data []E, lo, hi, first int) { if child >= hi { break } - if child+1 < hi && (data[first+child] < data[first+child+1]) { + if child+1 < hi && cmpLess(data[first+child], data[first+child+1]) { child++ } - if !(data[first+root] < data[first+child]) { + if !cmpLess(data[first+root], data[first+child]) { return } data[first+root], data[first+child] = data[first+child], data[first+root] @@ -107,7 +107,7 @@ func pdqsortOrdered[E constraints.Ordered](data []E, a, b, limit int) { // Probably the slice contains many duplicate elements, partition the slice into // elements equal to and elements greater than the pivot. 
- if a > 0 && !(data[a-1] < data[pivot]) { + if a > 0 && !cmpLess(data[a-1], data[pivot]) { mid := partitionEqualOrdered(data, a, b, pivot) a = mid continue @@ -138,10 +138,10 @@ func partitionOrdered[E constraints.Ordered](data []E, a, b, pivot int) (newpivo data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned - for i <= j && (data[i] < data[a]) { + for i <= j && cmpLess(data[i], data[a]) { i++ } - for i <= j && !(data[j] < data[a]) { + for i <= j && !cmpLess(data[j], data[a]) { j-- } if i > j { @@ -153,10 +153,10 @@ func partitionOrdered[E constraints.Ordered](data []E, a, b, pivot int) (newpivo j-- for { - for i <= j && (data[i] < data[a]) { + for i <= j && cmpLess(data[i], data[a]) { i++ } - for i <= j && !(data[j] < data[a]) { + for i <= j && !cmpLess(data[j], data[a]) { j-- } if i > j { @@ -177,10 +177,10 @@ func partitionEqualOrdered[E constraints.Ordered](data []E, a, b, pivot int) (ne i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned for { - for i <= j && !(data[a] < data[i]) { + for i <= j && !cmpLess(data[a], data[i]) { i++ } - for i <= j && (data[a] < data[j]) { + for i <= j && cmpLess(data[a], data[j]) { j-- } if i > j { @@ -201,7 +201,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool ) i := a + 1 for j := 0; j < maxSteps; j++ { - for i < b && !(data[i] < data[i-1]) { + for i < b && !cmpLess(data[i], data[i-1]) { i++ } @@ -218,7 +218,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool // Shift the smaller one to the left. if i-a >= 2 { for j := i - 1; j >= 1; j-- { - if !(data[j] < data[j-1]) { + if !cmpLess(data[j], data[j-1]) { break } data[j], data[j-1] = data[j-1], data[j] @@ -227,7 +227,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool // Shift the greater one to the right. 
if b-i >= 2 { for j := i + 1; j < b; j++ { - if !(data[j] < data[j-1]) { + if !cmpLess(data[j], data[j-1]) { break } data[j], data[j-1] = data[j-1], data[j] @@ -298,7 +298,7 @@ func choosePivotOrdered[E constraints.Ordered](data []E, a, b int) (pivot int, h // order2Ordered returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. func order2Ordered[E constraints.Ordered](data []E, a, b int, swaps *int) (int, int) { - if data[b] < data[a] { + if cmpLess(data[b], data[a]) { *swaps++ return b, a } @@ -389,7 +389,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { j := b for i < j { h := int(uint(i+j) >> 1) - if data[h] < data[a] { + if cmpLess(data[h], data[a]) { i = h + 1 } else { j = h @@ -413,7 +413,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { j := m for i < j { h := int(uint(i+j) >> 1) - if !(data[m] < data[h]) { + if !cmpLess(data[m], data[h]) { i = h + 1 } else { j = h @@ -440,7 +440,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { for start < r { c := int(uint(start+r) >> 1) - if !(data[p-c] < data[c]) { + if !cmpLess(data[p-c], data[c]) { start = c + 1 } else { r = c diff --git a/vendor/golang.org/x/exp/slog/doc.go b/vendor/golang.org/x/exp/slog/doc.go index 3b242591fc..4beaf86748 100644 --- a/vendor/golang.org/x/exp/slog/doc.go +++ b/vendor/golang.org/x/exp/slog/doc.go @@ -174,9 +174,9 @@ argument, as do their corresponding top-level functions. Although the convenience methods on Logger (Info and so on) and the corresponding top-level functions do not take a context, the alternatives ending -in "Ctx" do. For example, +in "Context" do. For example, - slog.InfoCtx(ctx, "message") + slog.InfoContext(ctx, "message") It is recommended to pass a context to an output method if one is available. 
diff --git a/vendor/golang.org/x/exp/slog/logger.go b/vendor/golang.org/x/exp/slog/logger.go index 6ad93bf8c9..e87ec9936c 100644 --- a/vendor/golang.org/x/exp/slog/logger.go +++ b/vendor/golang.org/x/exp/slog/logger.go @@ -167,7 +167,13 @@ func (l *Logger) Debug(msg string, args ...any) { l.log(nil, LevelDebug, msg, args...) } +// DebugContext logs at LevelDebug with the given context. +func (l *Logger) DebugContext(ctx context.Context, msg string, args ...any) { + l.log(ctx, LevelDebug, msg, args...) +} + // DebugCtx logs at LevelDebug with the given context. +// Deprecated: Use Logger.DebugContext. func (l *Logger) DebugCtx(ctx context.Context, msg string, args ...any) { l.log(ctx, LevelDebug, msg, args...) } @@ -177,7 +183,13 @@ func (l *Logger) Info(msg string, args ...any) { l.log(nil, LevelInfo, msg, args...) } +// InfoContext logs at LevelInfo with the given context. +func (l *Logger) InfoContext(ctx context.Context, msg string, args ...any) { + l.log(ctx, LevelInfo, msg, args...) +} + // InfoCtx logs at LevelInfo with the given context. +// Deprecated: Use Logger.InfoContext. func (l *Logger) InfoCtx(ctx context.Context, msg string, args ...any) { l.log(ctx, LevelInfo, msg, args...) } @@ -187,7 +199,13 @@ func (l *Logger) Warn(msg string, args ...any) { l.log(nil, LevelWarn, msg, args...) } +// WarnContext logs at LevelWarn with the given context. +func (l *Logger) WarnContext(ctx context.Context, msg string, args ...any) { + l.log(ctx, LevelWarn, msg, args...) +} + // WarnCtx logs at LevelWarn with the given context. +// Deprecated: Use Logger.WarnContext. func (l *Logger) WarnCtx(ctx context.Context, msg string, args ...any) { l.log(ctx, LevelWarn, msg, args...) } @@ -197,7 +215,13 @@ func (l *Logger) Error(msg string, args ...any) { l.log(nil, LevelError, msg, args...) } +// ErrorContext logs at LevelError with the given context. +func (l *Logger) ErrorContext(ctx context.Context, msg string, args ...any) { + l.log(ctx, LevelError, msg, args...) 
+} + // ErrorCtx logs at LevelError with the given context. +// Deprecated: Use Logger.ErrorContext. func (l *Logger) ErrorCtx(ctx context.Context, msg string, args ...any) { l.log(ctx, LevelError, msg, args...) } @@ -249,8 +273,8 @@ func Debug(msg string, args ...any) { Default().log(nil, LevelDebug, msg, args...) } -// DebugCtx calls Logger.DebugCtx on the default logger. -func DebugCtx(ctx context.Context, msg string, args ...any) { +// DebugContext calls Logger.DebugContext on the default logger. +func DebugContext(ctx context.Context, msg string, args ...any) { Default().log(ctx, LevelDebug, msg, args...) } @@ -259,8 +283,8 @@ func Info(msg string, args ...any) { Default().log(nil, LevelInfo, msg, args...) } -// InfoCtx calls Logger.InfoCtx on the default logger. -func InfoCtx(ctx context.Context, msg string, args ...any) { +// InfoContext calls Logger.InfoContext on the default logger. +func InfoContext(ctx context.Context, msg string, args ...any) { Default().log(ctx, LevelInfo, msg, args...) } @@ -269,8 +293,8 @@ func Warn(msg string, args ...any) { Default().log(nil, LevelWarn, msg, args...) } -// WarnCtx calls Logger.WarnCtx on the default logger. -func WarnCtx(ctx context.Context, msg string, args ...any) { +// WarnContext calls Logger.WarnContext on the default logger. +func WarnContext(ctx context.Context, msg string, args ...any) { Default().log(ctx, LevelWarn, msg, args...) } @@ -279,7 +303,31 @@ func Error(msg string, args ...any) { Default().log(nil, LevelError, msg, args...) } -// ErrorCtx calls Logger.ErrorCtx on the default logger. +// ErrorContext calls Logger.ErrorContext on the default logger. +func ErrorContext(ctx context.Context, msg string, args ...any) { + Default().log(ctx, LevelError, msg, args...) +} + +// DebugCtx calls Logger.DebugContext on the default logger. +// Deprecated: call DebugContext. +func DebugCtx(ctx context.Context, msg string, args ...any) { + Default().log(ctx, LevelDebug, msg, args...) 
+} + +// InfoCtx calls Logger.InfoContext on the default logger. +// Deprecated: call InfoContext. +func InfoCtx(ctx context.Context, msg string, args ...any) { + Default().log(ctx, LevelInfo, msg, args...) +} + +// WarnCtx calls Logger.WarnContext on the default logger. +// Deprecated: call WarnContext. +func WarnCtx(ctx context.Context, msg string, args ...any) { + Default().log(ctx, LevelWarn, msg, args...) +} + +// ErrorCtx calls Logger.ErrorContext on the default logger. +// Deprecated: call ErrorContext. func ErrorCtx(ctx context.Context, msg string, args ...any) { Default().log(ctx, LevelError, msg, args...) } diff --git a/vendor/modules.txt b/vendor/modules.txt index 87b3ee71f5..39edbf04c7 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -257,9 +257,19 @@ github.com/chromedp/sysutil # github.com/coreos/go-iptables v0.6.0 ## explicit; go 1.16 github.com/coreos/go-iptables/iptables +# github.com/coreos/go-systemd/v22 v22.4.0 +## explicit; go 1.12 +github.com/coreos/go-systemd/v22/dbus # github.com/davecgh/go-spew v1.1.1 ## explicit github.com/davecgh/go-spew/spew +# github.com/dblohm7/wingoes v0.0.0-20230821191801-fc76608aecf0 +## explicit; go 1.18 +github.com/dblohm7/wingoes +github.com/dblohm7/wingoes/com +github.com/dblohm7/wingoes/com/automation +github.com/dblohm7/wingoes/internal +github.com/dblohm7/wingoes/pe # github.com/demisto/goxforce v0.0.0-20160322194047-db8357535b1d ## explicit github.com/demisto/goxforce @@ -305,7 +315,7 @@ github.com/gobwas/pool/pbytes ## explicit; go 1.15 github.com/gobwas/ws github.com/gobwas/ws/wsutil -# github.com/godbus/dbus/v5 v5.1.0 +# github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 ## explicit; go 1.12 github.com/godbus/dbus/v5 # github.com/gofrs/uuid v4.4.0+incompatible @@ -332,6 +342,14 @@ github.com/google/go-cmp/cmp/internal/diff github.com/google/go-cmp/cmp/internal/flags github.com/google/go-cmp/cmp/internal/function github.com/google/go-cmp/cmp/internal/value +# 
github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c +## explicit; go 1.17 +github.com/google/nftables +github.com/google/nftables/alignedbuff +github.com/google/nftables/binaryutil +github.com/google/nftables/expr +github.com/google/nftables/internal/parseexprfunc +github.com/google/nftables/xt # github.com/google/uuid v1.3.0 ## explicit github.com/google/uuid @@ -468,6 +486,7 @@ github.com/mdlayher/genetlink ## explicit; go 1.18 github.com/mdlayher/netlink github.com/mdlayher/netlink/nlenc +github.com/mdlayher/netlink/nltest # github.com/mdlayher/sdnotify v1.0.0 ## explicit; go 1.18 github.com/mdlayher/sdnotify @@ -584,14 +603,15 @@ github.com/spf13/cobra # github.com/spf13/pflag v1.0.5 ## explicit; go 1.12 github.com/spf13/pflag -# github.com/stretchr/testify v1.8.2 -## explicit; go 1.13 +# github.com/stretchr/testify v1.8.4 +## explicit; go 1.20 github.com/stretchr/testify/assert # github.com/tailscale/certstore v0.1.1-0.20220316223106-78d6e1c49d8d ## explicit; go 1.12 github.com/tailscale/certstore -# github.com/tailscale/golang-x-crypto v0.0.0-20221115211329-17a3db2c30d2 +# github.com/tailscale/golang-x-crypto v0.0.0-20230713185742-f0b76a10a08e ## explicit; go 1.17 +github.com/tailscale/golang-x-crypto/acme github.com/tailscale/golang-x-crypto/chacha20 github.com/tailscale/golang-x-crypto/internal/alias github.com/tailscale/golang-x-crypto/ssh @@ -607,7 +627,7 @@ github.com/tailscale/goupnp/ssdp # github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 ## explicit; go 1.12 github.com/tailscale/netlink -# github.com/tailscale/wireguard-go v0.0.0-20230410165232-af172621b4dd +# github.com/tailscale/wireguard-go v0.0.0-20230824215414-93bd5cbf7fd8 ## explicit; go 1.20 github.com/tailscale/wireguard-go/conn github.com/tailscale/wireguard-go/conn/winrio @@ -687,7 +707,7 @@ github.com/yiya1989/sshkrb5/krb5forssh # go4.org/mem v0.0.0-20220726221520-4f986261bf13 ## explicit; go 1.14 go4.org/mem -# go4.org/netipx v0.0.0-20230303233057-f1b76eb4bb35 
+# go4.org/netipx v0.0.0-20230728180743-ad4cb58a6516 ## explicit; go 1.18 go4.org/netipx # golang.org/x/crypto v0.14.0 @@ -722,7 +742,7 @@ golang.org/x/crypto/scrypt golang.org/x/crypto/ssh golang.org/x/crypto/ssh/agent golang.org/x/crypto/ssh/internal/bcrypt_pbkdf -# golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 +# golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 ## explicit; go 1.20 golang.org/x/exp/constraints golang.org/x/exp/maps @@ -730,7 +750,7 @@ golang.org/x/exp/slices golang.org/x/exp/slog golang.org/x/exp/slog/internal golang.org/x/exp/slog/internal/buffer -# golang.org/x/mod v0.10.0 +# golang.org/x/mod v0.11.0 ## explicit; go 1.17 golang.org/x/mod/semver # golang.org/x/net v0.17.0 @@ -1104,12 +1124,14 @@ nhooyr.io/websocket/internal/bpool nhooyr.io/websocket/internal/errd nhooyr.io/websocket/internal/wsjs nhooyr.io/websocket/internal/xsync -# tailscale.com v1.44.0 -## explicit; go 1.20 +# tailscale.com v1.50.1 +## explicit; go 1.21 tailscale.com tailscale.com/atomicfile tailscale.com/client/tailscale tailscale.com/client/tailscale/apitype +tailscale.com/clientupdate +tailscale.com/clientupdate/distsign tailscale.com/control/controlbase tailscale.com/control/controlclient tailscale.com/control/controlhttp @@ -1136,7 +1158,6 @@ tailscale.com/ipn/store/kubestore tailscale.com/ipn/store/mem tailscale.com/kube tailscale.com/log/filelogger -tailscale.com/log/logheap tailscale.com/log/sockstatlog tailscale.com/logpolicy tailscale.com/logtail @@ -1146,6 +1167,7 @@ tailscale.com/metrics tailscale.com/net/connstats tailscale.com/net/dns tailscale.com/net/dns/publicdns +tailscale.com/net/dns/recursive tailscale.com/net/dns/resolvconffile tailscale.com/net/dns/resolver tailscale.com/net/dnscache @@ -1169,6 +1191,7 @@ tailscale.com/net/routetable tailscale.com/net/socks5 tailscale.com/net/sockstats tailscale.com/net/stun +tailscale.com/net/tcpinfo tailscale.com/net/tlsdial tailscale.com/net/tsaddr tailscale.com/net/tsdial @@ -1178,11 +1201,13 @@ 
tailscale.com/net/tstun/table tailscale.com/net/wsconn tailscale.com/paths tailscale.com/portlist +tailscale.com/proxymap tailscale.com/safesocket tailscale.com/smallzstd tailscale.com/syncs tailscale.com/tailcfg tailscale.com/tempfork/device +tailscale.com/tempfork/heap tailscale.com/tka tailscale.com/tsconst tailscale.com/tsd @@ -1219,20 +1244,26 @@ tailscale.com/util/groupmember tailscale.com/util/hashx tailscale.com/util/httpm tailscale.com/util/lineread +tailscale.com/util/linuxfw tailscale.com/util/mak tailscale.com/util/multierr tailscale.com/util/must +tailscale.com/util/osdiag +tailscale.com/util/osdiag/internal/wsc tailscale.com/util/osshare tailscale.com/util/pidowner tailscale.com/util/racebuild +tailscale.com/util/rands tailscale.com/util/ringbuffer tailscale.com/util/set tailscale.com/util/singleflight tailscale.com/util/slicesx tailscale.com/util/sysresources tailscale.com/util/systemd +tailscale.com/util/testenv tailscale.com/util/uniq tailscale.com/util/winutil +tailscale.com/util/winutil/authenticode tailscale.com/util/winutil/policy tailscale.com/version tailscale.com/version/distro diff --git a/vendor/tailscale.com/.gitignore b/vendor/tailscale.com/.gitignore index a613c538de..bea5627bc1 100644 --- a/vendor/tailscale.com/.gitignore +++ b/vendor/tailscale.com/.gitignore @@ -35,5 +35,10 @@ cmd/tailscaled/tailscaled # Ignore direnv nix-shell environment cache .direnv/ +# Ignore web client node modules +.vite/ +client/web/node_modules +client/web/build/assets + /gocross /dist diff --git a/vendor/tailscale.com/CODEOWNERS b/vendor/tailscale.com/CODEOWNERS new file mode 100644 index 0000000000..af9b0d9f95 --- /dev/null +++ b/vendor/tailscale.com/CODEOWNERS @@ -0,0 +1 @@ +/tailcfg/ @tailscale/control-protocol-owners diff --git a/vendor/tailscale.com/Dockerfile b/vendor/tailscale.com/Dockerfile index 14d5d06677..09a05f731e 100644 --- a/vendor/tailscale.com/Dockerfile +++ b/vendor/tailscale.com/Dockerfile @@ -31,7 +31,7 @@ # $ docker exec tailscaled 
tailscale status -FROM golang:1.20-alpine AS build-env +FROM golang:1.21-alpine AS build-env WORKDIR /go/src/tailscale diff --git a/vendor/tailscale.com/Makefile b/vendor/tailscale.com/Makefile index 4a9acb0b45..166fa19487 100644 --- a/vendor/tailscale.com/Makefile +++ b/vendor/tailscale.com/Makefile @@ -1,6 +1,7 @@ IMAGE_REPO ?= tailscale/tailscale SYNO_ARCH ?= "amd64" SYNO_DSM ?= "7" +TAGS ?= "latest" vet: ## Run go vet ./tool/go vet ./... @@ -36,6 +37,9 @@ buildlinuxarm: ## Build tailscale CLI for linux/arm buildwasm: ## Build tailscale CLI for js/wasm GOOS=js GOARCH=wasm ./tool/go install ./cmd/tsconnect/wasm ./cmd/tailscale/cli +buildplan9: + GOOS=plan9 GOARCH=amd64 ./tool/go install ./cmd/tailscale ./cmd/tailscaled + buildlinuxloong64: ## Build tailscale CLI for linux/loong64 GOOS=linux GOARCH=loong64 ./tool/go install tailscale.com/cmd/tailscale tailscale.com/cmd/tailscaled @@ -64,7 +68,7 @@ publishdevimage: ## Build and publish tailscale image to location specified by $ @test "${REPO}" != "ghcr.io/tailscale/tailscale" || (echo "REPO=... must not be ghcr.io/tailscale/tailscale" && exit 1) @test "${REPO}" != "tailscale/k8s-operator" || (echo "REPO=... must not be tailscale/k8s-operator" && exit 1) @test "${REPO}" != "ghcr.io/tailscale/k8s-operator" || (echo "REPO=... must not be ghcr.io/tailscale/k8s-operator" && exit 1) - TAGS=latest REPOS=${REPO} PUSH=true TARGET=client ./build_docker.sh + TAGS="${TAGS}" REPOS=${REPO} PUSH=true TARGET=client ./build_docker.sh publishdevoperator: ## Build and publish k8s-operator image to location specified by ${REPO} @test -n "${REPO}" || (echo "REPO=... required; e.g. REPO=ghcr.io/${USER}/tailscale" && exit 1) @@ -72,7 +76,7 @@ publishdevoperator: ## Build and publish k8s-operator image to location specifie @test "${REPO}" != "ghcr.io/tailscale/tailscale" || (echo "REPO=... must not be ghcr.io/tailscale/tailscale" && exit 1) @test "${REPO}" != "tailscale/k8s-operator" || (echo "REPO=... 
must not be tailscale/k8s-operator" && exit 1) @test "${REPO}" != "ghcr.io/tailscale/k8s-operator" || (echo "REPO=... must not be ghcr.io/tailscale/k8s-operator" && exit 1) - TAGS=latest REPOS=${REPO} PUSH=true TARGET=operator ./build_docker.sh + TAGS="${TAGS}" REPOS=${REPO} PUSH=true TARGET=operator ./build_docker.sh help: ## Show this help @echo "\nSpecify a command. The choices are:\n" diff --git a/vendor/tailscale.com/README.md b/vendor/tailscale.com/README.md index a76a0c3265..0eae446245 100644 --- a/vendor/tailscale.com/README.md +++ b/vendor/tailscale.com/README.md @@ -37,7 +37,7 @@ not open source. ## Building -We always require the latest Go release, currently Go 1.20. (While we build +We always require the latest Go release, currently Go 1.21. (While we build releases with our [Go fork](https://github.com/tailscale/go/), its use is not required.) diff --git a/vendor/tailscale.com/VERSION.txt b/vendor/tailscale.com/VERSION.txt index 372cf402c7..9cbd34da1a 100644 --- a/vendor/tailscale.com/VERSION.txt +++ b/vendor/tailscale.com/VERSION.txt @@ -1 +1 @@ -1.44.0 +1.50.1 diff --git a/vendor/tailscale.com/api.md b/vendor/tailscale.com/api.md index 028f71365f..51b75719d8 100644 --- a/vendor/tailscale.com/api.md +++ b/vendor/tailscale.com/api.md @@ -101,8 +101,8 @@ You can also [list all devices in the tailnet](#list-tailnet-devices) to get the ``` jsonc { // addresses (array of strings) is a list of Tailscale IP - // addresses for the device, including both ipv4 (formatted as 100.x.y.z) - // and ipv6 (formatted as fd7a:115c:a1e0:a:b:c:d:e) addresses. + // addresses for the device, including both IPv4 (formatted as 100.x.y.z) + // and IPv6 (formatted as fd7a:115c:a1e0:a:b:c:d:e) addresses. "addresses": [ "100.87.74.78", "fd7a:115c:a1e0:ac82:4843:ca90:697d:c36e" @@ -516,7 +516,8 @@ The ID of the device. #### `authorized` (required in `POST` body) -Specify whether the device is authorized. +Specify whether the device is authorized. 
False to deauthorize an authorized device, and true to authorize a new device or to re-authorize a previously deauthorized device. + ``` jsonc { @@ -1114,6 +1115,21 @@ Look at the response body to determine whether there was a problem within your A } ``` +If your tailnet has [user and group provisioning](https://tailscale.com/kb/1180/sso-okta-scim/) turned on, we will also warn you about +any groups that are used in the policy file that are not being synced from SCIM. Explicitly defined groups will not trigger this warning. + +```jsonc +{ + "message":"warning(s) found", + "data":[ + { + "user": "group:unknown@example.com", + "warnings":["group is not syncing from SCIM and will be ignored by rules in the policy file"] + } + ] +} +``` + ## List tailnet devices @@ -1222,6 +1238,11 @@ The remaining three methods operate on auth keys and API access tokens. // expirySeconds (int) is the duration in seconds a new key is valid. "expirySeconds": 86400 + + // description (string) is an optional short phrase that describes what + // this key is used for. It can be a maximum of 50 alphanumeric characters. + // Hyphens and underscores are also allowed. + "description": "short description of key purpose" } ``` @@ -1308,6 +1329,9 @@ Note the following about required vs. optional values: Specifies the duration in seconds until the key should expire. Defaults to 90 days if not supplied. +- **`description`:** Optional in `POST` body. + A short string specifying the purpose of the key. Can be a maximum of 50 alphanumeric characters. Hyphens and spaces are also allowed. 
+ ### Request example ``` jsonc @@ -1325,7 +1349,8 @@ curl "https://api.tailscale.com/api/v2/tailnet/example.com/keys" \ } } }, - "expirySeconds": 86400 + "expirySeconds": 86400, + "description": "dev access" }' ``` @@ -1351,7 +1376,8 @@ It holds the capabilities specified in the request and can no longer be retrieve "tags": [ "tag:example" ] } } - } + }, + "description": "dev access" } ``` @@ -1403,7 +1429,20 @@ The response is a JSON object with information about the key supplied. ] } } - } + }, + "description": "dev access" +} +``` + +Response for a revoked (deleted) or expired key will have an `invalid` field set to `true`: + +``` jsonc +{ + "id": "abc123456CNTRL", + "created": "2022-05-05T18:55:44Z", + "expires": "2022-08-03T18:55:44Z", + "revoked": "2023-04-01T20:50:00Z", + "invalid": true } ``` diff --git a/vendor/tailscale.com/client/tailscale/acl.go b/vendor/tailscale.com/client/tailscale/acl.go index 226e7f9165..ddc79506d7 100644 --- a/vendor/tailscale.com/client/tailscale/acl.go +++ b/vendor/tailscale.com/client/tailscale/acl.go @@ -150,8 +150,9 @@ func (c *Client) ACLHuJSON(ctx context.Context) (acl *ACLHuJSON, err error) { // ACLTestFailureSummary specifies the JSON format sent to the // JavaScript client to be rendered in the HTML. type ACLTestFailureSummary struct { - User string `json:"user"` - Errors []string `json:"errors"` + User string `json:"user,omitempty"` + Errors []string `json:"errors,omitempty"` + Warnings []string `json:"warnings,omitempty"` } // ACLTestError is ErrResponse but with an extra field to account for ACLTestFailureSummary. 
diff --git a/vendor/tailscale.com/client/tailscale/apitype/apitype.go b/vendor/tailscale.com/client/tailscale/apitype/apitype.go index e4c4e538f8..b63abf69c9 100644 --- a/vendor/tailscale.com/client/tailscale/apitype/apitype.go +++ b/vendor/tailscale.com/client/tailscale/apitype/apitype.go @@ -10,12 +10,14 @@ import "tailscale.com/tailcfg" const LocalAPIHost = "local-tailscaled.sock" // WhoIsResponse is the JSON type returned by tailscaled debug server's /whois?ip=$IP handler. +// In successful whois responses, Node and UserProfile are never nil. type WhoIsResponse struct { Node *tailcfg.Node UserProfile *tailcfg.UserProfile - // Caps are extra capabilities that the remote Node has to this node. - Caps []string `json:",omitempty"` + // CapMap is a map of capabilities to their values. + // See tailcfg.PeerCapMap and tailcfg.PeerCapability for details. + CapMap tailcfg.PeerCapMap } // FileTarget is a node to which files can be sent, and the PeerAPI diff --git a/vendor/tailscale.com/client/tailscale/apitype/controltype.go b/vendor/tailscale.com/client/tailscale/apitype/controltype.go index 9283288622..9a623be319 100644 --- a/vendor/tailscale.com/client/tailscale/apitype/controltype.go +++ b/vendor/tailscale.com/client/tailscale/apitype/controltype.go @@ -4,12 +4,13 @@ package apitype type DNSConfig struct { - Resolvers []DNSResolver `json:"resolvers"` - FallbackResolvers []DNSResolver `json:"fallbackResolvers"` - Routes map[string][]DNSResolver `json:"routes"` - Domains []string `json:"domains"` - Nameservers []string `json:"nameservers"` - Proxied bool `json:"proxied"` + Resolvers []DNSResolver `json:"resolvers"` + FallbackResolvers []DNSResolver `json:"fallbackResolvers"` + Routes map[string][]DNSResolver `json:"routes"` + Domains []string `json:"domains"` + Nameservers []string `json:"nameservers"` + Proxied bool `json:"proxied"` + TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` } type DNSResolver struct { diff --git 
a/vendor/tailscale.com/client/tailscale/localclient.go b/vendor/tailscale.com/client/tailscale/localclient.go index 611a49b297..32cb1041b1 100644 --- a/vendor/tailscale.com/client/tailscale/localclient.go +++ b/vendor/tailscale.com/client/tailscale/localclient.go @@ -37,6 +37,7 @@ import ( "tailscale.com/tka" "tailscale.com/types/key" "tailscale.com/types/tkatype" + "tailscale.com/util/cmpx" ) // defaultLocalClient is the default LocalClient when using the legacy @@ -139,6 +140,10 @@ func (lc *LocalClient) doLocalRequestNiceError(req *http.Request) (*http.Respons all, _ := io.ReadAll(res.Body) return nil, &AccessDeniedError{errors.New(errorMessageFromBody(all))} } + if res.StatusCode == http.StatusPreconditionFailed { + all, _ := io.ReadAll(res.Body) + return nil, &PreconditionsFailedError{errors.New(errorMessageFromBody(all))} + } return res, nil } if ue, ok := err.(*url.Error); ok { @@ -169,6 +174,24 @@ func IsAccessDeniedError(err error) bool { return errors.As(err, &ae) } +// PreconditionsFailedError is returned when the server responds +// with an HTTP 412 status code. +type PreconditionsFailedError struct { + err error +} + +func (e *PreconditionsFailedError) Error() string { + return fmt.Sprintf("Preconditions failed: %v", e.err) +} + +func (e *PreconditionsFailedError) Unwrap() error { return e.err } + +// IsPreconditionsFailedError reports whether err is or wraps an PreconditionsFailedError. +func IsPreconditionsFailedError(err error) bool { + var ae *PreconditionsFailedError + return errors.As(err, &ae) +} + // bestError returns either err, or if body contains a valid JSON // object of type errorJSON, its non-empty error body. 
func bestError(err error, body []byte) error { @@ -197,27 +220,42 @@ func SetVersionMismatchHandler(f func(clientVer, serverVer string)) { } func (lc *LocalClient) send(ctx context.Context, method, path string, wantStatus int, body io.Reader) ([]byte, error) { + slurp, _, err := lc.sendWithHeaders(ctx, method, path, wantStatus, body, nil) + return slurp, err +} + +func (lc *LocalClient) sendWithHeaders( + ctx context.Context, + method, + path string, + wantStatus int, + body io.Reader, + h http.Header, +) ([]byte, http.Header, error) { if jr, ok := body.(jsonReader); ok && jr.err != nil { - return nil, jr.err // fail early if there was a JSON marshaling error + return nil, nil, jr.err // fail early if there was a JSON marshaling error } req, err := http.NewRequestWithContext(ctx, method, "http://"+apitype.LocalAPIHost+path, body) if err != nil { - return nil, err + return nil, nil, err + } + if h != nil { + req.Header = h } res, err := lc.doLocalRequestNiceError(req) if err != nil { - return nil, err + return nil, nil, err } defer res.Body.Close() slurp, err := io.ReadAll(res.Body) if err != nil { - return nil, err + return nil, nil, err } if res.StatusCode != wantStatus { err = fmt.Errorf("%v: %s", res.Status, bytes.TrimSpace(slurp)) - return nil, bestError(err, slurp) + return nil, nil, bestError(err, slurp) } - return slurp, nil + return slurp, res.Header, nil } func (lc *LocalClient) get200(ctx context.Context, path string) ([]byte, error) { @@ -259,6 +297,28 @@ func (lc *LocalClient) DaemonMetrics(ctx context.Context) ([]byte, error) { return lc.get200(ctx, "/localapi/v0/metrics") } +// IncrementCounter increments the value of a Tailscale daemon's counter +// metric by the given delta. If the metric has yet to exist, a new counter +// metric is created and initialized to delta. +// +// IncrementCounter does not support gauge metrics or negative delta values. 
+func (lc *LocalClient) IncrementCounter(ctx context.Context, name string, delta int) error { + type metricUpdate struct { + Name string `json:"name"` + Type string `json:"type"` + Value int `json:"value"` // amount to increment by + } + if delta < 0 { + return errors.New("negative delta not allowed") + } + _, err := lc.send(ctx, "POST", "/localapi/v0/upload-client-metrics", 200, jsonBody([]metricUpdate{{ + Name: name, + Type: "counter", + Value: delta, + }})) + return err +} + // TailDaemonLogs returns a stream the Tailscale daemon's logs as they arrive. // Close the context to stop the stream. func (lc *LocalClient) TailDaemonLogs(ctx context.Context) (io.Reader, error) { @@ -369,15 +429,65 @@ func (lc *LocalClient) DebugAction(ctx context.Context, action string) error { return nil } +// DebugResultJSON invokes a debug action and returns its result as something JSON-able. +// These are development tools and subject to change or removal over time. +func (lc *LocalClient) DebugResultJSON(ctx context.Context, action string) (any, error) { + body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, nil) + if err != nil { + return nil, fmt.Errorf("error %w: %s", err, body) + } + var x any + if err := json.Unmarshal(body, &x); err != nil { + return nil, err + } + return x, nil +} + +// DebugPortmapOpts contains options for the DebugPortmap command. +type DebugPortmapOpts struct { + // Duration is how long the mapping should be created for. It defaults + // to 5 seconds if not set. + Duration time.Duration + + // Type is the kind of portmap to debug. The empty string instructs the + // portmap client to perform all known types. Other valid options are + // "pmp", "pcp", and "upnp". + Type string + + // GatewayAddr specifies the gateway address used during portmapping. + // If set, SelfAddr must also be set. If unset, it will be + // autodetected. 
+ GatewayAddr netip.Addr + + // SelfAddr specifies the gateway address used during portmapping. If + // set, GatewayAddr must also be set. If unset, it will be + // autodetected. + SelfAddr netip.Addr + + // LogHTTP instructs the debug-portmap endpoint to print all HTTP + // requests and responses made to the logs. + LogHTTP bool +} + // DebugPortmap invokes the debug-portmap endpoint, and returns an // io.ReadCloser that can be used to read the logs that are printed during this // process. -func (lc *LocalClient) DebugPortmap(ctx context.Context, duration time.Duration, ty, gwSelf string) (io.ReadCloser, error) { +// +// opts can be nil; if so, default values will be used. +func (lc *LocalClient) DebugPortmap(ctx context.Context, opts *DebugPortmapOpts) (io.ReadCloser, error) { vals := make(url.Values) - vals.Set("duration", duration.String()) - vals.Set("type", ty) - if gwSelf != "" { - vals.Set("gateway_and_self", gwSelf) + if opts == nil { + opts = &DebugPortmapOpts{} + } + + vals.Set("duration", cmpx.Or(opts.Duration, 5*time.Second).String()) + vals.Set("type", opts.Type) + vals.Set("log_http", strconv.FormatBool(opts.LogHTTP)) + + if opts.GatewayAddr.IsValid() != opts.SelfAddr.IsValid() { + return nil, fmt.Errorf("both GatewayAddr and SelfAddr must be provided if one is") + } else if opts.GatewayAddr.IsValid() { + vals.Set("gateway_and_self", fmt.Sprintf("%s/%s", opts.GatewayAddr, opts.SelfAddr)) } req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/debug-portmap?"+vals.Encode(), nil) @@ -807,11 +917,25 @@ func (lc *LocalClient) ExpandSNIName(ctx context.Context, name string) (fqdn str return "", false } +// PingOpts contains options for the ping request. +// +// The zero value is valid, which means to use defaults. +type PingOpts struct { + // Size is the length of the ping message in bytes. It's ignored if it's + // smaller than the minimum message size. 
+ // + // For disco pings, it specifies the length of the packet's payload. That + // is, it includes the disco headers and message, but not the IP and UDP + // headers. + Size int +} + // Ping sends a ping of the provided type to the provided IP and waits -// for its response. -func (lc *LocalClient) Ping(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType) (*ipnstate.PingResult, error) { +// for its response. The opts type specifies additional options. +func (lc *LocalClient) PingWithOpts(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType, opts PingOpts) (*ipnstate.PingResult, error) { v := url.Values{} v.Set("ip", ip.String()) + v.Set("size", strconv.Itoa(opts.Size)) v.Set("type", string(pingtype)) body, err := lc.send(ctx, "POST", "/localapi/v0/ping?"+v.Encode(), 200, nil) if err != nil { @@ -820,6 +944,12 @@ func (lc *LocalClient) Ping(ctx context.Context, ip netip.Addr, pingtype tailcfg return decodeJSON[*ipnstate.PingResult](body) } +// Ping sends a ping of the provided type to the provided IP and waits +// for its response. +func (lc *LocalClient) Ping(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType) (*ipnstate.PingResult, error) { + return lc.PingWithOpts(ctx, ip, pingtype, PingOpts{}) +} + // NetworkLockStatus fetches information about the tailnet key authority, if one is configured. func (lc *LocalClient) NetworkLockStatus(ctx context.Context) (*ipnstate.NetworkLockStatus, error) { body, err := lc.send(ctx, "GET", "/localapi/v0/tka/status", 200, nil) @@ -946,10 +1076,65 @@ func (lc *LocalClient) NetworkLockForceLocalDisable(ctx context.Context) error { return nil } +// NetworkLockVerifySigningDeeplink verifies the network lock deeplink contained +// in url and returns information extracted from it. 
+func (lc *LocalClient) NetworkLockVerifySigningDeeplink(ctx context.Context, url string) (*tka.DeeplinkValidationResult, error) { + vr := struct { + URL string + }{url} + + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/verify-deeplink", 200, jsonBody(vr)) + if err != nil { + return nil, fmt.Errorf("sending verify-deeplink: %w", err) + } + + return decodeJSON[*tka.DeeplinkValidationResult](body) +} + +// NetworkLockGenRecoveryAUM generates an AUM for recovering from a tailnet-lock key compromise. +func (lc *LocalClient) NetworkLockGenRecoveryAUM(ctx context.Context, removeKeys []tkatype.KeyID, forkFrom tka.AUMHash) ([]byte, error) { + vr := struct { + Keys []tkatype.KeyID + ForkFrom string + }{removeKeys, forkFrom.String()} + + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/generate-recovery-aum", 200, jsonBody(vr)) + if err != nil { + return nil, fmt.Errorf("sending generate-recovery-aum: %w", err) + } + + return body, nil +} + +// NetworkLockCosignRecoveryAUM co-signs a recovery AUM using the node's tailnet lock key. +func (lc *LocalClient) NetworkLockCosignRecoveryAUM(ctx context.Context, aum tka.AUM) ([]byte, error) { + r := bytes.NewReader(aum.Serialize()) + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/cosign-recovery-aum", 200, r) + if err != nil { + return nil, fmt.Errorf("sending cosign-recovery-aum: %w", err) + } + + return body, nil +} + +// NetworkLockSubmitRecoveryAUM submits a recovery AUM to the control plane. +func (lc *LocalClient) NetworkLockSubmitRecoveryAUM(ctx context.Context, aum tka.AUM) error { + r := bytes.NewReader(aum.Serialize()) + _, err := lc.send(ctx, "POST", "/localapi/v0/tka/submit-recovery-aum", 200, r) + if err != nil { + return fmt.Errorf("sending cosign-recovery-aum: %w", err) + } + return nil +} + // SetServeConfig sets or replaces the serving settings. // If config is nil, settings are cleared and serving is disabled. 
func (lc *LocalClient) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error { - _, err := lc.send(ctx, "POST", "/localapi/v0/serve-config", 200, jsonBody(config)) + h := make(http.Header) + if config != nil { + h.Set("If-Match", config.ETag) + } + _, _, err := lc.sendWithHeaders(ctx, "POST", "/localapi/v0/serve-config", 200, jsonBody(config), h) if err != nil { return fmt.Errorf("sending serve config: %w", err) } @@ -968,11 +1153,19 @@ func (lc *LocalClient) NetworkLockDisable(ctx context.Context, secret []byte) er // // If the serve config is empty, it returns (nil, nil). func (lc *LocalClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { - body, err := lc.send(ctx, "GET", "/localapi/v0/serve-config", 200, nil) + body, h, err := lc.sendWithHeaders(ctx, "GET", "/localapi/v0/serve-config", 200, nil, nil) if err != nil { return nil, fmt.Errorf("getting serve config: %w", err) } - return getServeConfigFromJSON(body) + sc, err := getServeConfigFromJSON(body) + if err != nil { + return nil, err + } + if sc == nil { + sc = new(ipn.ServeConfig) + } + sc.ETag = h.Get("Etag") + return sc, nil } func getServeConfigFromJSON(body []byte) (sc *ipn.ServeConfig, err error) { @@ -1073,6 +1266,27 @@ func (lc *LocalClient) DeleteProfile(ctx context.Context, profile ipn.ProfileID) return err } +// QueryFeature makes a request for instructions on how to enable +// a feature, such as Funnel, for the node's tailnet. If relevant, +// this includes a control server URL the user can visit to enable +// the feature. +// +// If you are looking to use QueryFeature, you'll likely want to +// use cli.enableFeatureInteractive instead, which handles the logic +// of wraping QueryFeature and translating its response into an +// interactive flow for the user, including using the IPN notify bus +// to block until the feature has been enabled. +// +// 2023-08-09: Valid feature values are "serve" and "funnel". 
+func (lc *LocalClient) QueryFeature(ctx context.Context, feature string) (*tailcfg.QueryFeatureResponse, error) { + v := url.Values{"feature": {feature}} + body, err := lc.send(ctx, "POST", "/localapi/v0/query-feature?"+v.Encode(), 200, nil) + if err != nil { + return nil, fmt.Errorf("error %w: %s", err, body) + } + return decodeJSON[*tailcfg.QueryFeatureResponse](body) +} + func (lc *LocalClient) DebugDERPRegion(ctx context.Context, regionIDOrCode string) (*ipnstate.DebugDERPRegionReport, error) { v := url.Values{"region": {regionIDOrCode}} body, err := lc.send(ctx, "POST", "/localapi/v0/debug-derp-region?"+v.Encode(), 200, nil) diff --git a/vendor/tailscale.com/client/tailscale/required_version.go b/vendor/tailscale.com/client/tailscale/required_version.go index 4b44bf2702..ff15fc78a0 100644 --- a/vendor/tailscale.com/client/tailscale/required_version.go +++ b/vendor/tailscale.com/client/tailscale/required_version.go @@ -1,10 +1,10 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !go1.20 +//go:build !go1.21 package tailscale func init() { - you_need_Go_1_20_to_compile_Tailscale() + you_need_Go_1_21_to_compile_Tailscale() } diff --git a/vendor/tailscale.com/clientupdate/clientupdate.go b/vendor/tailscale.com/clientupdate/clientupdate.go new file mode 100644 index 0000000000..b788c73212 --- /dev/null +++ b/vendor/tailscale.com/clientupdate/clientupdate.go @@ -0,0 +1,1094 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package clientupdate implements tailscale client update for all supported +// platforms. This package can be used from both tailscaled and tailscale +// binaries. 
+package clientupdate + +import ( + "archive/tar" + "bufio" + "bytes" + "compress/gzip" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "net/http" + "os" + "os/exec" + "path" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + + "github.com/google/uuid" + "tailscale.com/clientupdate/distsign" + "tailscale.com/types/logger" + "tailscale.com/util/winutil" + "tailscale.com/version" + "tailscale.com/version/distro" +) + +const ( + CurrentTrack = "" + StableTrack = "stable" + UnstableTrack = "unstable" +) + +func versionToTrack(v string) (string, error) { + _, rest, ok := strings.Cut(v, ".") + if !ok { + return "", fmt.Errorf("malformed version %q", v) + } + minorStr, _, ok := strings.Cut(rest, ".") + if !ok { + return "", fmt.Errorf("malformed version %q", v) + } + minor, err := strconv.Atoi(minorStr) + if err != nil { + return "", fmt.Errorf("malformed version %q", v) + } + if minor%2 == 0 { + return "stable", nil + } + return "unstable", nil +} + +// Arguments contains arguments needed to run an update. +type Arguments struct { + // Version can be a specific version number or one of the predefined track + // constants: + // + // - CurrentTrack will use the latest version from the same track as the + // running binary + // - StableTrack and UnstableTrack will use the latest versions of the + // corresponding tracks + // + // Leaving this empty is the same as using CurrentTrack. + Version string + // AppStore forces a local app store check, even if the current binary was + // not installed via an app store. TODO(cpalmer): Remove this. + AppStore bool + // Logf is a logger for update progress messages. + Logf logger.Logf + // Confirm is called when a new version is available and should return true + // if this new version should be installed. When Confirm returns false, the + // update is aborted. + Confirm func(newVer string) bool + // PkgsAddr is the address of the pkgs server to fetch updates from. 
+ // Defaults to "https://pkgs.tailscale.com". + PkgsAddr string +} + +func (args Arguments) validate() error { + if args.Confirm == nil { + return errors.New("missing Confirm callback in Arguments") + } + if args.Logf == nil { + return errors.New("missing Logf callback in Arguments") + } + return nil +} + +type Updater struct { + Arguments + track string + // Update is a platform-specific method that updates the installation. May be + // nil (not all platforms support updates from within Tailscale). + Update func() error +} + +func NewUpdater(args Arguments) (*Updater, error) { + up := Updater{ + Arguments: args, + } + up.Update = up.getUpdateFunction() + if up.Update == nil { + return nil, errors.ErrUnsupported + } + switch up.Version { + case StableTrack, UnstableTrack: + up.track = up.Version + case CurrentTrack: + if version.IsUnstableBuild() { + up.track = UnstableTrack + } else { + up.track = StableTrack + } + default: + var err error + up.track, err = versionToTrack(args.Version) + if err != nil { + return nil, err + } + } + if up.Arguments.PkgsAddr == "" { + up.Arguments.PkgsAddr = "https://pkgs.tailscale.com" + } + return &up, nil +} + +type updateFunction func() error + +func (up *Updater) getUpdateFunction() updateFunction { + switch runtime.GOOS { + case "windows": + return up.updateWindows + case "linux": + switch distro.Get() { + case distro.Synology: + return up.updateSynology + case distro.Debian: // includes Ubuntu + return up.updateDebLike + case distro.Arch: + return up.updateArchLike + case distro.Alpine: + return up.updateAlpineLike + } + switch { + case haveExecutable("pacman"): + return up.updateArchLike + case haveExecutable("apt-get"): // TODO(awly): add support for "apt" + // The distro.Debian switch case above should catch most apt-based + // systems, but add this fallback just in case. 
+ return up.updateDebLike + case haveExecutable("dnf"): + return up.updateFedoraLike("dnf") + case haveExecutable("yum"): + return up.updateFedoraLike("yum") + case haveExecutable("apk"): + return up.updateAlpineLike + } + // If nothing matched, fall back to tarball updates. + if up.Update == nil { + return up.updateLinuxBinary + } + case "darwin": + switch { + case !up.Arguments.AppStore && !version.IsSandboxedMacOS(): + return nil + case !up.Arguments.AppStore && strings.HasSuffix(os.Getenv("HOME"), "/io.tailscale.ipn.macsys/Data"): + return up.updateMacSys + default: + return up.updateMacAppStore + } + case "freebsd": + return up.updateFreeBSD + } + return nil +} + +// Update runs a single update attempt using the platform-specific mechanism. +// +// On Windows, this copies the calling binary and re-executes it to apply the +// update. The calling binary should handle an "update" subcommand and call +// this function again for the re-executed binary to proceed. +func Update(args Arguments) error { + if err := args.validate(); err != nil { + return err + } + up, err := NewUpdater(args) + if err != nil { + return err + } + return up.Update() +} + +func (up *Updater) confirm(ver string) bool { + if version.Short() == ver { + up.Logf("already running %v; no update needed", ver) + return false + } + if up.Confirm != nil { + return up.Confirm(ver) + } + return true +} + +const synoinfoConfPath = "/etc/synoinfo.conf" + +func (up *Updater) updateSynology() error { + if up.Version != "" { + return errors.New("installing a specific version on Synology is not supported") + } + + // Get the latest version and list of SPKs from pkgs.tailscale.com. 
+ dsmVersion := distro.DSMVersion() + osName := fmt.Sprintf("dsm%d", dsmVersion) + arch, err := synoArch(runtime.GOARCH, synoinfoConfPath) + if err != nil { + return err + } + latest, err := latestPackages(up.track) + if err != nil { + return err + } + spkName := latest.SPKs[osName][arch] + if spkName == "" { + return fmt.Errorf("cannot find Synology package for os=%s arch=%s, please report a bug with your device model", osName, arch) + } + + if !up.confirm(latest.SPKsVersion) { + return nil + } + if err := requireRoot(); err != nil { + return err + } + + // Download the SPK into a temporary directory. + spkDir, err := os.MkdirTemp("", "tailscale-update") + if err != nil { + return err + } + pkgsPath := fmt.Sprintf("%s/%s", up.track, spkName) + spkPath := filepath.Join(spkDir, path.Base(pkgsPath)) + if err := up.downloadURLToFile(pkgsPath, spkPath); err != nil { + return err + } + + // Install the SPK. Run via nohup to allow install to succeed when we're + // connected over tailscale ssh and this parent process dies. Otherwise, if + // you abort synopkg install mid-way, tailscaled is not restarted. + cmd := exec.Command("nohup", "synopkg", "install", spkPath) + // Don't attach cmd.Stdout to os.Stdout because nohup will redirect that + // into nohup.out file. synopkg doesn't have any progress output anyway, it + // just spits out a JSON result when done. + out, err := cmd.CombinedOutput() + if err != nil { + if dsmVersion == 6 && bytes.Contains(out, []byte("error = [290]")) { + return fmt.Errorf("synopkg install failed: %w\noutput:\n%s\nplease make sure that packages from 'Any publisher' are allowed in the Package Center (Package Center -> Settings -> Trust Level -> Any publisher)", err, out) + } + return fmt.Errorf("synopkg install failed: %w\noutput:\n%s", err, out) + } + if dsmVersion == 6 { + // DSM6 does not automatically restart the package on install. Do it + // manually. 
+ cmd := exec.Command("nohup", "synopkg", "start", "Tailscale") + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("synopkg start failed: %w\noutput:\n%s", err, out) + } + } + return nil +} + +// synoArch returns the Synology CPU architecture matching one of the SPK +// architectures served from pkgs.tailscale.com. +func synoArch(goArch, synoinfoPath string) (string, error) { + // Most Synology boxes just use a different arch name from GOARCH. + arch := map[string]string{ + "amd64": "x86_64", + "386": "i686", + "arm64": "armv8", + }[goArch] + + if arch == "" { + // Here's the fun part, some older ARM boxes require you to use SPKs + // specifically for their CPU. See + // https://github.com/SynoCommunity/spksrc/wiki/Synology-and-SynoCommunity-Package-Architectures + // for a complete list. + // + // Some CPUs will map to neither this list nor the goArch map above, and we + // don't have SPKs for them. + cpu, err := parseSynoinfo(synoinfoPath) + if err != nil { + return "", fmt.Errorf("failed to get CPU architecture: %w", err) + } + switch cpu { + case "88f6281", "88f6282", "hi3535", "alpine", "armada370", + "armada375", "armada38x", "armadaxp", "comcerto2k", "monaco": + arch = cpu + default: + return "", fmt.Errorf("unsupported Synology CPU architecture %q (Go arch %q), please report a bug at https://github.com/tailscale/tailscale/issues/new/choose", cpu, goArch) + } + } + return arch, nil +} + +func parseSynoinfo(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + // Look for a line like: + // unique="synology_88f6282_413j" + // Extract the CPU in the middle (88f6282 in the above example). 
+ s := bufio.NewScanner(f) + for s.Scan() { + l := s.Text() + if !strings.HasPrefix(l, "unique=") { + continue + } + parts := strings.SplitN(l, "_", 3) + if len(parts) != 3 { + return "", fmt.Errorf(`malformed %q: found %q, expected format like 'unique="synology_$cpu_$model'`, path, l) + } + return parts[1], nil + } + return "", fmt.Errorf(`missing "unique=" field in %q`, path) +} + +func (up *Updater) updateDebLike() error { + if err := requireRoot(); err != nil { + return err + } + if err := exec.Command("dpkg", "--status", "tailscale").Run(); err != nil && isExitError(err) { + // Tailscale was not installed via apt, update via tarball download + // instead. + return up.updateLinuxBinary() + } + ver, err := requestedTailscaleVersion(up.Version, up.track) + if err != nil { + return err + } + if !up.confirm(ver) { + return nil + } + + if updated, err := updateDebianAptSourcesList(up.track); err != nil { + return err + } else if updated { + up.Logf("Updated %s to use the %s track", aptSourcesFile, up.track) + } + + cmd := exec.Command("apt-get", "update", + // Only update the tailscale repo, not the other ones, treating + // the tailscale.list file as the main "sources.list" file. 
+ "-o", "Dir::Etc::SourceList=sources.list.d/tailscale.list", + // Disable the "sources.list.d" directory: + "-o", "Dir::Etc::SourceParts=-", + // Don't forget about packages in the other repos just because + // we're not updating them: + "-o", "APT::Get::List-Cleanup=0", + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return err + } + + cmd = exec.Command("apt-get", "install", "--yes", "--allow-downgrades", "tailscale="+ver) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return err + } + + return nil +} + +const aptSourcesFile = "/etc/apt/sources.list.d/tailscale.list" + +// updateDebianAptSourcesList updates the /etc/apt/sources.list.d/tailscale.list +// file to make sure it has the provided track (stable or unstable) in it. +// +// If it already has the right track (including containing both stable and +// unstable), it does nothing. +func updateDebianAptSourcesList(dstTrack string) (rewrote bool, err error) { + was, err := os.ReadFile(aptSourcesFile) + if err != nil { + return false, err + } + newContent, err := updateDebianAptSourcesListBytes(was, dstTrack) + if err != nil { + return false, err + } + if bytes.Equal(was, newContent) { + return false, nil + } + return true, os.WriteFile(aptSourcesFile, newContent, 0644) +} + +func updateDebianAptSourcesListBytes(was []byte, dstTrack string) (newContent []byte, err error) { + trackURLPrefix := []byte("https://pkgs.tailscale.com/" + dstTrack + "/") + var buf bytes.Buffer + var changes int + bs := bufio.NewScanner(bytes.NewReader(was)) + hadCorrect := false + commentLine := regexp.MustCompile(`^\s*\#`) + pkgsURL := regexp.MustCompile(`\bhttps://pkgs\.tailscale\.com/((un)?stable)/`) + for bs.Scan() { + line := bs.Bytes() + if !commentLine.Match(line) { + line = pkgsURL.ReplaceAllFunc(line, func(m []byte) []byte { + if bytes.Equal(m, trackURLPrefix) { + hadCorrect = true + } else { + changes++ + } + return trackURLPrefix + }) + } + 
buf.Write(line) + buf.WriteByte('\n') + } + if hadCorrect || (changes == 1 && bytes.Equal(bytes.TrimSpace(was), bytes.TrimSpace(buf.Bytes()))) { + // Unchanged or close enough. + return was, nil + } + if changes != 1 { + // No changes, or an unexpected number of changes (what?). Bail. + // They probably editted it by hand and we don't know what to do. + return nil, fmt.Errorf("unexpected/unsupported %s contents", aptSourcesFile) + } + return buf.Bytes(), nil +} + +func (up *Updater) updateArchLike() error { + if err := exec.Command("pacman", "--query", "tailscale").Run(); err != nil && isExitError(err) { + // Tailscale was not installed via pacman, update via tarball download + // instead. + return up.updateLinuxBinary() + } + // Arch maintainer asked us not to implement "tailscale update" or + // auto-updates on Arch-based distros: + // https://github.com/tailscale/tailscale/issues/6995#issuecomment-1687080106 + return errors.New(`individual package updates are not supported on Arch-based distros, only full-system updates are: https://wiki.archlinux.org/title/System_maintenance#Partial_upgrades_are_unsupported. +you can use "pacman --sync --refresh --sysupgrade" or "pacman -Syu" to upgrade the system, including Tailscale.`) +} + +const yumRepoConfigFile = "/etc/yum.repos.d/tailscale.repo" + +// updateFedoraLike updates tailscale on any distros in the Fedora family, +// specifically anything that uses "dnf" or "yum" package managers. The actual +// package manager is passed via packageManager. +func (up *Updater) updateFedoraLike(packageManager string) func() error { + return func() (err error) { + if err := requireRoot(); err != nil { + return err + } + if err := exec.Command(packageManager, "info", "--installed", "tailscale").Run(); err != nil && isExitError(err) { + // Tailscale was not installed via yum/dnf, update via tarball + // download instead. 
+ return up.updateLinuxBinary() + } + defer func() { + if err != nil { + err = fmt.Errorf(`%w; you can try updating using "%s upgrade tailscale"`, err, packageManager) + } + }() + + ver, err := requestedTailscaleVersion(up.Version, up.track) + if err != nil { + return err + } + if !up.confirm(ver) { + return nil + } + + if updated, err := updateYUMRepoTrack(yumRepoConfigFile, up.track); err != nil { + return err + } else if updated { + up.Logf("Updated %s to use the %s track", yumRepoConfigFile, up.track) + } + + cmd := exec.Command(packageManager, "install", "--assumeyes", fmt.Sprintf("tailscale-%s-1", ver)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return err + } + return nil + } +} + +// updateYUMRepoTrack updates the repoFile file to make sure it has the +// provided track (stable or unstable) in it. +func updateYUMRepoTrack(repoFile, dstTrack string) (rewrote bool, err error) { + was, err := os.ReadFile(repoFile) + if err != nil { + return false, err + } + + urlRe := regexp.MustCompile(`^(baseurl|gpgkey)=https://pkgs\.tailscale\.com/(un)?stable/`) + urlReplacement := fmt.Sprintf("$1=https://pkgs.tailscale.com/%s/", dstTrack) + + s := bufio.NewScanner(bytes.NewReader(was)) + newContent := bytes.NewBuffer(make([]byte, 0, len(was))) + for s.Scan() { + line := s.Text() + // Handle repo section name, like "[tailscale-stable]". + if len(line) > 0 && line[0] == '[' { + if !strings.HasPrefix(line, "[tailscale-") { + return false, fmt.Errorf("%q does not look like a tailscale repo file, it contains an unexpected %q section", repoFile, line) + } + fmt.Fprintf(newContent, "[tailscale-%s]\n", dstTrack) + continue + } + // Update the track mentioned in repo name. + if strings.HasPrefix(line, "name=") { + fmt.Fprintf(newContent, "name=Tailscale %s\n", dstTrack) + continue + } + // Update the actual repo URLs. 
+ if strings.HasPrefix(line, "baseurl=") || strings.HasPrefix(line, "gpgkey=") { + fmt.Fprintln(newContent, urlRe.ReplaceAllString(line, urlReplacement)) + continue + } + fmt.Fprintln(newContent, line) + } + if bytes.Equal(was, newContent.Bytes()) { + return false, nil + } + return true, os.WriteFile(repoFile, newContent.Bytes(), 0644) +} + +func (up *Updater) updateAlpineLike() (err error) { + if up.Version != "" { + return errors.New("installing a specific version on Alpine-based distros is not supported") + } + if err := requireRoot(); err != nil { + return err + } + if err := exec.Command("apk", "info", "--installed", "tailscale").Run(); err != nil && isExitError(err) { + // Tailscale was not installed via apk, update via tarball download + // instead. + return up.updateLinuxBinary() + } + + defer func() { + if err != nil { + err = fmt.Errorf(`%w; you can try updating using "apk upgrade tailscale"`, err) + } + }() + + out, err := exec.Command("apk", "update").CombinedOutput() + if err != nil { + return fmt.Errorf("failed refresh apk repository indexes: %w, output: %q", err, out) + } + out, err = exec.Command("apk", "info", "tailscale").CombinedOutput() + if err != nil { + return fmt.Errorf("failed checking apk for latest tailscale version: %w, output: %q", err, out) + } + ver, err := parseAlpinePackageVersion(out) + if err != nil { + return fmt.Errorf(`failed to parse latest version from "apk info tailscale": %w`, err) + } + if !up.confirm(ver) { + return nil + } + + cmd := exec.Command("apk", "upgrade", "tailscale") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed tailscale update using apk: %w", err) + } + return nil +} + +func parseAlpinePackageVersion(out []byte) (string, error) { + s := bufio.NewScanner(bytes.NewReader(out)) + for s.Scan() { + // The line should look like this: + // tailscale-1.44.2-r0 description: + line := strings.TrimSpace(s.Text()) + if !strings.HasPrefix(line, 
"tailscale-") { + continue + } + parts := strings.SplitN(line, "-", 3) + if len(parts) < 3 { + return "", fmt.Errorf("malformed info line: %q", line) + } + return parts[1], nil + } + return "", errors.New("tailscale version not found in output") +} + +func (up *Updater) updateMacSys() error { + return errors.New("NOTREACHED: On MacSys builds, `tailscale update` is handled in Swift to launch the GUI updater") +} + +func (up *Updater) updateMacAppStore() error { + out, err := exec.Command("defaults", "read", "/Library/Preferences/com.apple.commerce.plist", "AutoUpdate").CombinedOutput() + if err != nil { + return fmt.Errorf("can't check App Store auto-update setting: %w, output: %q", err, string(out)) + } + const on = "1\n" + if string(out) != on { + up.Logf("NOTE: Automatic updating for App Store apps is turned off. You can change this setting in System Settings (search for ‘update’).") + } + + out, err = exec.Command("softwareupdate", "--list").CombinedOutput() + if err != nil { + return fmt.Errorf("can't check App Store for available updates: %w, output: %q", err, string(out)) + } + + newTailscale := parseSoftwareupdateList(out) + if newTailscale == "" { + up.Logf("no Tailscale update available") + return nil + } + + newTailscaleVer := strings.TrimPrefix(newTailscale, "Tailscale-") + if !up.confirm(newTailscaleVer) { + return nil + } + + cmd := exec.Command("sudo", "softwareupdate", "--install", newTailscale) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("can't install App Store update for Tailscale: %w", err) + } + return nil +} + +var macOSAppStoreListPattern = regexp.MustCompile(`(?m)^\s+\*\s+Label:\s*(Tailscale-\d[\d\.]+)`) + +// parseSoftwareupdateList searches the output of `softwareupdate --list` on +// Darwin and returns the matching Tailscale package label. If there is none, +// returns the empty string. +// +// See TestParseSoftwareupdateList for example inputs. 
+func parseSoftwareupdateList(stdout []byte) string { + matches := macOSAppStoreListPattern.FindSubmatch(stdout) + if len(matches) < 2 { + return "" + } + return string(matches[1]) +} + +// winMSIEnv is the environment variable that, if set, is the MSI file for the +// update command to install. It's passed like this so we can stop the +// tailscale.exe process from running before the msiexec process runs and tries +// to overwrite ourselves. +const winMSIEnv = "TS_UPDATE_WIN_MSI" + +var ( + verifyAuthenticode func(string) error // or nil on non-Windows + markTempFileFunc func(string) error // or nil on non-Windows +) + +func (up *Updater) updateWindows() error { + if msi := os.Getenv(winMSIEnv); msi != "" { + up.Logf("installing %v ...", msi) + if err := up.installMSI(msi); err != nil { + up.Logf("MSI install failed: %v", err) + return err + } + up.Logf("success.") + return nil + } + ver, err := requestedTailscaleVersion(up.Version, up.track) + if err != nil { + return err + } + arch := runtime.GOARCH + if arch == "386" { + arch = "x86" + } + + if !up.confirm(ver) { + return nil + } + if !winutil.IsCurrentProcessElevated() { + return errors.New("must be run as Administrator") + } + + tsDir := filepath.Join(os.Getenv("ProgramData"), "Tailscale") + msiDir := filepath.Join(tsDir, "MSICache") + if fi, err := os.Stat(tsDir); err != nil { + return fmt.Errorf("expected %s to exist, got stat error: %w", tsDir, err) + } else if !fi.IsDir() { + return fmt.Errorf("expected %s to be a directory; got %v", tsDir, fi.Mode()) + } + if err := os.MkdirAll(msiDir, 0700); err != nil { + return err + } + pkgsPath := fmt.Sprintf("%s/tailscale-setup-%s-%s.msi", up.track, ver, arch) + msiTarget := filepath.Join(msiDir, path.Base(pkgsPath)) + if err := up.downloadURLToFile(pkgsPath, msiTarget); err != nil { + return err + } + + up.Logf("verifying MSI authenticode...") + if err := verifyAuthenticode(msiTarget); err != nil { + return fmt.Errorf("authenticode verification of %s failed: %w", 
msiTarget, err) + } + up.Logf("authenticode verification succeeded") + + up.Logf("making tailscale.exe copy to switch to...") + selfCopy, err := makeSelfCopy() + if err != nil { + return err + } + defer os.Remove(selfCopy) + up.Logf("running tailscale.exe copy for final install...") + + cmd := exec.Command(selfCopy, "update") + cmd.Env = append(os.Environ(), winMSIEnv+"="+msiTarget) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + if err := cmd.Start(); err != nil { + return err + } + // Once it's started, exit ourselves, so the binary is free + // to be replaced. + os.Exit(0) + panic("unreachable") +} + +func (up *Updater) installMSI(msi string) error { + var err error + for tries := 0; tries < 2; tries++ { + cmd := exec.Command("msiexec.exe", "/i", filepath.Base(msi), "/quiet", "/promptrestart", "/qn") + cmd.Dir = filepath.Dir(msi) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + err = cmd.Run() + if err == nil { + break + } + uninstallVersion := version.Short() + if v := os.Getenv("TS_DEBUG_UNINSTALL_VERSION"); v != "" { + uninstallVersion = v + } + // Assume it's a downgrade, which msiexec won't permit. Uninstall our current version first. 
+ up.Logf("Uninstalling current version %q for downgrade...", uninstallVersion) + cmd = exec.Command("msiexec.exe", "/x", msiUUIDForVersion(uninstallVersion), "/norestart", "/qn") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + err = cmd.Run() + up.Logf("msiexec uninstall: %v", err) + } + return err +} + +func msiUUIDForVersion(ver string) string { + arch := runtime.GOARCH + if arch == "386" { + arch = "x86" + } + track, err := versionToTrack(ver) + if err != nil { + track = UnstableTrack + } + msiURL := fmt.Sprintf("https://pkgs.tailscale.com/%s/tailscale-setup-%s-%s.msi", track, ver, arch) + return "{" + strings.ToUpper(uuid.NewSHA1(uuid.NameSpaceURL, []byte(msiURL)).String()) + "}" +} + +func makeSelfCopy() (tmpPathExe string, err error) { + selfExe, err := os.Executable() + if err != nil { + return "", err + } + f, err := os.Open(selfExe) + if err != nil { + return "", err + } + defer f.Close() + f2, err := os.CreateTemp("", "tailscale-updater-*.exe") + if err != nil { + return "", err + } + if f := markTempFileFunc; f != nil { + if err := f(f2.Name()); err != nil { + return "", err + } + } + if _, err := io.Copy(f2, f); err != nil { + f2.Close() + return "", err + } + return f2.Name(), f2.Close() +} + +func (up *Updater) downloadURLToFile(pathSrc, fileDst string) (ret error) { + c, err := distsign.NewClient(up.Logf, up.PkgsAddr) + if err != nil { + return err + } + return c.Download(context.Background(), pathSrc, fileDst) +} + +func (up *Updater) updateFreeBSD() (err error) { + if up.Version != "" { + return errors.New("installing a specific version on FreeBSD is not supported") + } + if err := requireRoot(); err != nil { + return err + } + if err := exec.Command("pkg", "query", "%n", "tailscale").Run(); err != nil && isExitError(err) { + // Tailscale was not installed via pkg and we don't pre-compile + // binaries for it. 
+ return errors.New("Tailscale was not installed via pkg, binary updates on FreeBSD are not supported; please reinstall Tailscale using pkg or update manually") + } + + defer func() { + if err != nil { + err = fmt.Errorf(`%w; you can try updating using "pkg upgrade tailscale"`, err) + } + }() + + out, err := exec.Command("pkg", "update").CombinedOutput() + if err != nil { + return fmt.Errorf("failed refresh pkg repository indexes: %w, output: %q", err, out) + } + out, err = exec.Command("pkg", "rquery", "%v", "tailscale").CombinedOutput() + if err != nil { + return fmt.Errorf("failed checking pkg for latest tailscale version: %w, output: %q", err, out) + } + ver := string(bytes.TrimSpace(out)) + if !up.confirm(ver) { + return nil + } + + cmd := exec.Command("pkg", "upgrade", "tailscale") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed tailscale update using pkg: %w", err) + } + return nil +} + +func (up *Updater) updateLinuxBinary() error { + ver, err := requestedTailscaleVersion(up.Version, up.track) + if err != nil { + return err + } + if !up.confirm(ver) { + return nil + } + // Root is needed to overwrite binaries and restart systemd unit. 
+ if err := requireRoot(); err != nil { + return err + } + + dlPath, err := up.downloadLinuxTarball(ver) + if err != nil { + return err + } + up.Logf("Extracting %q", dlPath) + if err := up.unpackLinuxTarball(dlPath); err != nil { + return err + } + if err := os.Remove(dlPath); err != nil { + up.Logf("failed to clean up %q: %v", dlPath, err) + } + if err := restartSystemdUnit(context.Background()); err != nil { + if errors.Is(err, errors.ErrUnsupported) { + up.Logf("Tailscale binaries updated successfully.\nPlease restart tailscaled to finish the update.") + } else { + up.Logf("Tailscale binaries updated successfully, but failed to restart tailscaled: %s.\nPlease restart tailscaled to finish the update.", err) + } + } else { + up.Logf("Success") + } + + return nil +} + +func (up *Updater) downloadLinuxTarball(ver string) (string, error) { + dlDir, err := os.UserCacheDir() + if err != nil { + return "", err + } + dlDir = filepath.Join(dlDir, "tailscale-update") + if err := os.MkdirAll(dlDir, 0700); err != nil { + return "", err + } + pkgsPath := fmt.Sprintf("%s/tailscale_%s_%s.tgz", up.track, ver, runtime.GOARCH) + dlPath := filepath.Join(dlDir, path.Base(pkgsPath)) + if err := up.downloadURLToFile(pkgsPath, dlPath); err != nil { + return "", err + } + return dlPath, nil +} + +func (up *Updater) unpackLinuxTarball(path string) error { + tailscale, tailscaled, err := binaryPaths() + if err != nil { + return err + } + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + gr, err := gzip.NewReader(f) + if err != nil { + return err + } + defer gr.Close() + tr := tar.NewReader(gr) + files := make(map[string]int) + wantFiles := map[string]int{ + "tailscale": 1, + "tailscaled": 1, + } + for { + th, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed extracting %q: %w", path, err) + } + // TODO(awly): try to also extract tailscaled.service. 
The tricky part + // is fixing up binary paths in that file if they differ from where + // local tailscale/tailscaled are installed. Also, this may not be a + // systemd distro. + switch filepath.Base(th.Name) { + case "tailscale": + files["tailscale"]++ + if err := writeFile(tr, tailscale+".new", 0755); err != nil { + return fmt.Errorf("failed extracting the new tailscale binary from %q: %w", path, err) + } + case "tailscaled": + files["tailscaled"]++ + if err := writeFile(tr, tailscaled+".new", 0755); err != nil { + return fmt.Errorf("failed extracting the new tailscaled binary from %q: %w", path, err) + } + } + } + if !maps.Equal(files, wantFiles) { + return fmt.Errorf("%q has missing or duplicate files: got %v, want %v", path, files, wantFiles) + } + + // Only place the files in final locations after everything extracted correctly. + if err := os.Rename(tailscale+".new", tailscale); err != nil { + return err + } + up.Logf("Updated %s", tailscale) + if err := os.Rename(tailscaled+".new", tailscaled); err != nil { + return err + } + up.Logf("Updated %s", tailscaled) + return nil +} + +func writeFile(r io.Reader, path string, perm os.FileMode) error { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove existing file at %q: %w", path, err) + } + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, perm) + if err != nil { + return err + } + defer f.Close() + if _, err := io.Copy(f, r); err != nil { + return err + } + return f.Close() +} + +// Var allows overriding this in tests. +var binaryPaths = func() (tailscale, tailscaled string, err error) { + // This can be either tailscale or tailscaled. + this, err := os.Executable() + if err != nil { + return "", "", err + } + otherName := "tailscaled" + if filepath.Base(this) == "tailscaled" { + otherName = "tailscale" + } + // Try to find the other binary in the same directory. 
+ other := filepath.Join(filepath.Dir(this), otherName) + _, err = os.Stat(other) + if os.IsNotExist(err) { + // If it's not in the same directory, try to find it in $PATH. + other, err = exec.LookPath(otherName) + } + if err != nil { + return "", "", fmt.Errorf("cannot find %q in neither %q nor $PATH: %w", otherName, filepath.Dir(this), err) + } + if otherName == "tailscaled" { + return this, other, nil + } else { + return other, this, nil + } +} + +func haveExecutable(name string) bool { + path, err := exec.LookPath(name) + return err == nil && path != "" +} + +func requestedTailscaleVersion(ver, track string) (string, error) { + if ver != "" { + return ver, nil + } + return LatestTailscaleVersion(track) +} + +// LatestTailscaleVersion returns the latest released version for the given +// track from pkgs.tailscale.com. +func LatestTailscaleVersion(track string) (string, error) { + if track == CurrentTrack { + if version.IsUnstableBuild() { + track = UnstableTrack + } else { + track = StableTrack + } + } + + latest, err := latestPackages(track) + if err != nil { + return "", err + } + if latest.Version == "" { + return "", fmt.Errorf("no latest version found for %q track", track) + } + return latest.Version, nil +} + +type trackPackages struct { + Version string + Tarballs map[string]string + TarballsVersion string + Exes []string + ExesVersion string + MSIs map[string]string + MSIsVersion string + MacZips map[string]string + MacZipsVersion string + SPKs map[string]map[string]string + SPKsVersion string +} + +func latestPackages(track string) (*trackPackages, error) { + url := fmt.Sprintf("https://pkgs.tailscale.com/%s/?mode=json&os=%s", track, runtime.GOOS) + res, err := http.Get(url) + if err != nil { + return nil, fmt.Errorf("fetching latest tailscale version: %w", err) + } + defer res.Body.Close() + var latest trackPackages + if err := json.NewDecoder(res.Body).Decode(&latest); err != nil { + return nil, fmt.Errorf("decoding JSON: %v: %w", res.Status, err) + } 
+ return &latest, nil +} + +func requireRoot() error { + if os.Geteuid() == 0 { + return nil + } + switch runtime.GOOS { + case "linux": + return errors.New("must be root; use sudo") + case "freebsd", "openbsd": + return errors.New("must be root; use doas") + default: + return errors.New("must be root") + } +} + +func isExitError(err error) bool { + var exitErr *exec.ExitError + return errors.As(err, &exitErr) +} diff --git a/vendor/tailscale.com/clientupdate/clientupdate_windows.go b/vendor/tailscale.com/clientupdate/clientupdate_windows.go new file mode 100644 index 0000000000..2f6899a605 --- /dev/null +++ b/vendor/tailscale.com/clientupdate/clientupdate_windows.go @@ -0,0 +1,28 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Windows-specific stuff that can't go in clientupdate.go because it needs +// x/sys/windows. + +package clientupdate + +import ( + "golang.org/x/sys/windows" + "tailscale.com/util/winutil/authenticode" +) + +func init() { + markTempFileFunc = markTempFileWindows + verifyAuthenticode = verifyTailscale +} + +func markTempFileWindows(name string) error { + name16 := windows.StringToUTF16Ptr(name) + return windows.MoveFileEx(name16, nil, windows.MOVEFILE_DELAY_UNTIL_REBOOT) +} + +const certSubjectTailscale = "Tailscale Inc." + +func verifyTailscale(path string) error { + return authenticode.Verify(path, certSubjectTailscale) +} diff --git a/vendor/tailscale.com/clientupdate/distsign/distsign.go b/vendor/tailscale.com/clientupdate/distsign/distsign.go new file mode 100644 index 0000000000..b48321f8f5 --- /dev/null +++ b/vendor/tailscale.com/clientupdate/distsign/distsign.go @@ -0,0 +1,485 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package distsign implements signature and validation of arbitrary +// distributable files. 
+// +// There are 3 parties in this exchange: +// - builder, which creates files, signs them with signing keys and publishes +// to server +// - server, which distributes public signing keys, files and signatures +// - client, which downloads files and signatures from server, and validates +// the signatures +// +// There are 2 types of keys: +// - signing keys, that sign individual distributable files on the builder +// - root keys, that sign signing keys and are kept offline +// +// root keys -(sign)-> signing keys -(sign)-> files +// +// All keys are asymmetric Ed25519 key pairs. +// +// The server serves static files under some known prefix. The kinds of files are: +// - distsign.pub - bundle of PEM-encoded public signing keys +// - distsign.pub.sig - signature of distsign.pub using one of the root keys +// - $file - any distributable file +// - $file.sig - signature of $file using any of the signing keys +// +// The root public keys are baked into the client software at compile time. +// These keys are long-lived and prove the validity of current signing keys +// from distsign.pub. To rotate root keys, a new client release must be +// published, they are not rotated dynamically. There are multiple root keys in +// different locations specifically to allow this rotation without using the +// discarded root key for any new signatures. +// +// The signing public keys are fetched by the client dynamically before every +// download and can be rotated more readily, assuming that most deployed +// clients trust the root keys used to issue fresh signing keys. 
+package distsign + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "hash" + "io" + "log" + "net/http" + "net/url" + "os" + "time" + + "github.com/hdevalence/ed25519consensus" + "golang.org/x/crypto/blake2s" + "tailscale.com/net/tshttpproxy" + "tailscale.com/types/logger" + "tailscale.com/util/must" +) + +const ( + pemTypeRootPrivate = "ROOT PRIVATE KEY" + pemTypeRootPublic = "ROOT PUBLIC KEY" + pemTypeSigningPrivate = "SIGNING PRIVATE KEY" + pemTypeSigningPublic = "SIGNING PUBLIC KEY" + + downloadSizeLimit = 1 << 29 // 512MB + signingKeysSizeLimit = 1 << 20 // 1MB + signatureSizeLimit = ed25519.SignatureSize +) + +// RootKey is a root key used to sign signing keys. +type RootKey struct { + k ed25519.PrivateKey +} + +// GenerateRootKey generates a new root key pair and encodes it as PEM. +func GenerateRootKey() (priv, pub []byte, err error) { + pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: pemTypeRootPrivate, + Bytes: []byte(priv), + }), pem.EncodeToMemory(&pem.Block{ + Type: pemTypeRootPublic, + Bytes: []byte(pub), + }), nil +} + +// ParseRootKey parses the PEM-encoded private root key. The key must be in the +// same format as returned by GenerateRootKey. +func ParseRootKey(privKey []byte) (*RootKey, error) { + k, err := parsePrivateKey(privKey, pemTypeRootPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root key: %w", err) + } + return &RootKey{k: k}, nil +} + +// SignSigningKeys signs the bundle of public signing keys. The bundle must be +// a sequence of PEM blocks joined with newlines. +func (r *RootKey) SignSigningKeys(pubBundle []byte) ([]byte, error) { + if _, err := ParseSigningKeyBundle(pubBundle); err != nil { + return nil, err + } + return ed25519.Sign(r.k, pubBundle), nil +} + +// SigningKey is a signing key used to sign packages. 
+type SigningKey struct { + k ed25519.PrivateKey +} + +// GenerateSigningKey generates a new signing key pair and encodes it as PEM. +func GenerateSigningKey() (priv, pub []byte, err error) { + pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: pemTypeSigningPrivate, + Bytes: []byte(priv), + }), pem.EncodeToMemory(&pem.Block{ + Type: pemTypeSigningPublic, + Bytes: []byte(pub), + }), nil +} + +// ParseSigningKey parses the PEM-encoded private signing key. The key must be +// in the same format as returned by GenerateSigningKey. +func ParseSigningKey(privKey []byte) (*SigningKey, error) { + k, err := parsePrivateKey(privKey, pemTypeSigningPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root key: %w", err) + } + return &SigningKey{k: k}, nil +} + +// SignPackageHash signs the hash and the length of a package. Use PackageHash +// to compute the inputs. +func (s *SigningKey) SignPackageHash(hash []byte, len int64) ([]byte, error) { + if len <= 0 { + return nil, fmt.Errorf("package length must be positive, got %d", len) + } + msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) + return ed25519.Sign(s.k, msg), nil +} + +// PackageHash is a hash.Hash that counts the number of bytes written. Use it +// to get the hash and length inputs to SigningKey.SignPackageHash. +type PackageHash struct { + hash.Hash + len int64 +} + +// NewPackageHash returns an initialized PackageHash using BLAKE2s. +func NewPackageHash() *PackageHash { + h, err := blake2s.New256(nil) + if err != nil { + // Should never happen with a nil key passed to blake2s. + panic(err) + } + return &PackageHash{Hash: h} +} + +func (ph *PackageHash) Write(b []byte) (int, error) { + ph.len += int64(len(b)) + return ph.Hash.Write(b) +} + +// Reset the PackageHash to its initial state. 
+func (ph *PackageHash) Reset() { + ph.len = 0 + ph.Hash.Reset() +} + +// Len returns the total number of bytes written. +func (ph *PackageHash) Len() int64 { return ph.len } + +// Client downloads and validates files from a distribution server. +type Client struct { + logf logger.Logf + roots []ed25519.PublicKey + pkgsAddr *url.URL +} + +// NewClient returns a new client for distribution server located at pkgsAddr, +// and uses embedded root keys from the roots/ subdirectory of this package. +func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) { + if logf == nil { + logf = log.Printf + } + u, err := url.Parse(pkgsAddr) + if err != nil { + return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err) + } + return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil +} + +func (c *Client) url(path string) string { + return c.pkgsAddr.JoinPath(path).String() +} + +// Download fetches a file at path srcPath from pkgsAddr passed in NewClient. +// The file is downloaded to dstPath and its signature is validated using the +// embedded root keys. Download returns an error if anything goes wrong with +// the actual file download or with signature validation. +func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcPath) + sigURL := srcURL + ".sig" + + c.logf("Downloading %q", srcURL) + dstPathUnverified := dstPath + ".unverified" + hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit) + if err != nil { + return err + } + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + // Best-effort clean up of downloaded package. + os.Remove(dstPathUnverified) + return err + } + msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) + if !VerifyAny(sigPub, msg, sig) { + // Best-effort clean up of downloaded package. 
+ os.Remove(dstPathUnverified) + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL) + } + c.logf("Signature OK") + + if err := os.Rename(dstPathUnverified, dstPath); err != nil { + return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath) + } + + return nil +} + +// ValidateLocalBinary fetches the latest signature associated with the binary +// at srcURLPath and uses it to validate the file located on disk via +// localFilePath. ValidateLocalBinary returns an error if anything goes wrong +// with the signature download or with signature validation. +func (c *Client) ValidateLocalBinary(srcURLPath, localFilePath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcURLPath) + sigURL := srcURL + ".sig" + + localFile, err := os.Open(localFilePath) + if err != nil { + return err + } + defer localFile.Close() + + h := NewPackageHash() + _, err = io.Copy(h, localFile) + if err != nil { + return err + } + hash, hashLen := h.Sum(nil), h.Len() + + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return err + } + + msg := binary.LittleEndian.AppendUint64(hash, uint64(hashLen)) + if !VerifyAny(sigPub, msg, sig) { + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, localFilePath) + } + c.logf("Signature OK") + + return nil +} + +// signingKeys fetches current signing keys from the server and validates them +// against the roots. Should be called before validation of any downloaded file +// to get the fresh keys. 
+func (c *Client) signingKeys() ([]ed25519.PublicKey, error) { + keyURL := c.url("distsign.pub") + sigURL := keyURL + ".sig" + raw, err := fetch(keyURL, signingKeysSizeLimit) + if err != nil { + return nil, err + } + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return nil, err + } + if !VerifyAny(c.roots, raw, sig) { + return nil, fmt.Errorf("signature %q for key %q does not validate with any known root key; either you are under attack, or running a very old version of Tailscale with outdated root keys", sigURL, keyURL) + } + + keys, err := ParseSigningKeyBundle(raw) + if err != nil { + return nil, fmt.Errorf("cannot parse signing key bundle from %q: %w", keyURL, err) + } + return keys, nil +} + +// fetch reads the response body from url into memory, up to limit bytes. +func fetch(url string, limit int64) ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return io.ReadAll(io.LimitReader(resp.Body, limit)) +} + +// download writes the response body of url into a local file at dst, up to +// limit bytes. On success, the returned value is a BLAKE2s hash of the file. 
+func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Proxy = tshttpproxy.ProxyFromEnvironment + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + headReq := must.Get(http.NewRequestWithContext(quickCtx, http.MethodHead, url, nil)) + + res, err := hc.Do(headReq) + if err != nil { + return nil, 0, err + } + if res.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status) + } + if res.ContentLength <= 0 { + return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) + } + c.logf("Download size: %v", res.ContentLength) + + dlReq := must.Get(http.NewRequestWithContext(ctx, http.MethodGet, url, nil)) + dlRes, err := hc.Do(dlReq) + if err != nil { + return nil, 0, err + } + defer dlRes.Body.Close() + // TODO(bradfitz): resume from existing partial file on disk + if dlRes.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) + } + + of, err := os.Create(dst) + if err != nil { + return nil, 0, err + } + defer of.Close() + pw := &progressWriter{total: res.ContentLength, logf: c.logf} + h := NewPackageHash() + n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) + if err != nil { + return nil, n, err + } + if n != res.ContentLength { + return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) + } + if err := dlRes.Body.Close(); err != nil { + return nil, n, err + } + if err := of.Close(); err != nil { + return nil, n, err + } + pw.print() + + return h.Sum(nil), h.Len(), nil +} + +type progressWriter struct { + done int64 + total int64 + lastPrint time.Time + logf logger.Logf +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + pw.done += int64(len(p)) + if time.Since(pw.lastPrint) > 2*time.Second 
{ + pw.print() + } + return len(p), nil +} + +func (pw *progressWriter) print() { + pw.lastPrint = time.Now() + pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) +} + +func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) { + b, rest := pem.Decode(data) + if b == nil { + return nil, errors.New("failed to decode PEM data") + } + if len(rest) > 0 { + return nil, errors.New("trailing PEM data") + } + if b.Type != typeTag { + return nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + if len(b.Bytes) != ed25519.PrivateKeySize { + return nil, errors.New("private key has incorrect length for an Ed25519 private key") + } + return ed25519.PrivateKey(b.Bytes), nil +} + +// ParseSigningKeyBundle parses the bundle of PEM-encoded public signing keys. +func ParseSigningKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { + return parsePublicKeyBundle(bundle, pemTypeSigningPublic) +} + +// ParseRootKeyBundle parses the bundle of PEM-encoded public root keys. 
+func ParseRootKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { + return parsePublicKeyBundle(bundle, pemTypeRootPublic) +} + +func parsePublicKeyBundle(bundle []byte, typeTag string) ([]ed25519.PublicKey, error) { + var keys []ed25519.PublicKey + for len(bundle) > 0 { + pub, rest, err := parsePublicKey(bundle, typeTag) + if err != nil { + return nil, err + } + keys = append(keys, pub) + bundle = rest + } + if len(keys) == 0 { + return nil, errors.New("no signing keys found in the bundle") + } + return keys, nil +} + +func parseSinglePublicKey(data []byte, typeTag string) (ed25519.PublicKey, error) { + pub, rest, err := parsePublicKey(data, typeTag) + if err != nil { + return nil, err + } + if len(rest) > 0 { + return nil, errors.New("trailing PEM data") + } + return pub, err +} + +func parsePublicKey(data []byte, typeTag string) (pub ed25519.PublicKey, rest []byte, retErr error) { + b, rest := pem.Decode(data) + if b == nil { + return nil, nil, errors.New("failed to decode PEM data") + } + if b.Type != typeTag { + return nil, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + if len(b.Bytes) != ed25519.PublicKeySize { + return nil, nil, errors.New("public key has incorrect length for an Ed25519 public key") + } + return ed25519.PublicKey(b.Bytes), rest, nil +} + +// VerifyAny verifies whether sig is valid for msg using any of the keys. +// VerifyAny will panic if any of the keys have the wrong size for Ed25519. 
+func VerifyAny(keys []ed25519.PublicKey, msg, sig []byte) bool { + for _, k := range keys { + if ed25519consensus.Verify(k, msg, sig) { + return true + } + } + return false +} diff --git a/vendor/tailscale.com/clientupdate/distsign/roots.go b/vendor/tailscale.com/clientupdate/distsign/roots.go new file mode 100644 index 0000000000..d5b47b7b62 --- /dev/null +++ b/vendor/tailscale.com/clientupdate/distsign/roots.go @@ -0,0 +1,54 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package distsign + +import ( + "crypto/ed25519" + "embed" + "errors" + "fmt" + "path" + "path/filepath" + "sync" +) + +//go:embed roots +var rootsFS embed.FS + +var roots = sync.OnceValue(func() []ed25519.PublicKey { + roots, err := parseRoots() + if err != nil { + panic(err) + } + return roots +}) + +func parseRoots() ([]ed25519.PublicKey, error) { + files, err := rootsFS.ReadDir("roots") + if err != nil { + return nil, err + } + var keys []ed25519.PublicKey + for _, f := range files { + if !f.Type().IsRegular() { + continue + } + if filepath.Ext(f.Name()) != ".pem" { + continue + } + raw, err := rootsFS.ReadFile(path.Join("roots", f.Name())) + if err != nil { + return nil, err + } + key, err := parseSinglePublicKey(raw, pemTypeRootPublic) + if err != nil { + return nil, fmt.Errorf("parsing root key %q: %w", f.Name(), err) + } + keys = append(keys, key) + } + if len(keys) == 0 { + return nil, errors.New("no embedded root keys, please check clientupdate/distsign/roots/") + } + return keys, nil +} diff --git a/vendor/tailscale.com/clientupdate/distsign/roots/crawshaw-root.pem b/vendor/tailscale.com/clientupdate/distsign/roots/crawshaw-root.pem new file mode 100644 index 0000000000..f80b9aec78 --- /dev/null +++ b/vendor/tailscale.com/clientupdate/distsign/roots/crawshaw-root.pem @@ -0,0 +1,3 @@ +-----BEGIN ROOT PUBLIC KEY----- +Psrabv2YNiEDhPlnLVSMtB5EKACm7zxvKxfvYD4i7X8= +-----END ROOT PUBLIC KEY----- diff --git 
a/vendor/tailscale.com/clientupdate/distsign/roots/distsign-dev-root-pub.pem b/vendor/tailscale.com/clientupdate/distsign/roots/distsign-dev-root-pub.pem new file mode 100644 index 0000000000..f21d898a01 --- /dev/null +++ b/vendor/tailscale.com/clientupdate/distsign/roots/distsign-dev-root-pub.pem @@ -0,0 +1,3 @@ +-----BEGIN ROOT PUBLIC KEY----- +Muw5GkO5mASsJ7k6kS+svfuanr6XcW9I7fPGtyqOTeI= +-----END ROOT PUBLIC KEY----- diff --git a/vendor/tailscale.com/clientupdate/systemd_linux.go b/vendor/tailscale.com/clientupdate/systemd_linux.go new file mode 100644 index 0000000000..810f7dd552 --- /dev/null +++ b/vendor/tailscale.com/clientupdate/systemd_linux.go @@ -0,0 +1,37 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package clientupdate + +import ( + "context" + "errors" + "fmt" + + "github.com/coreos/go-systemd/v22/dbus" +) + +func restartSystemdUnit(ctx context.Context) error { + c, err := dbus.NewWithContext(ctx) + if err != nil { + // Likely not a systemd-managed distro. 
+ return errors.ErrUnsupported + } + defer c.Close() + if err := c.ReloadContext(ctx); err != nil { + return fmt.Errorf("failed to reload tailsacled.service: %w", err) + } + ch := make(chan string, 1) + if _, err := c.RestartUnitContext(ctx, "tailscaled.service", "replace", ch); err != nil { + return fmt.Errorf("failed to restart tailsacled.service: %w", err) + } + select { + case res := <-ch: + if res != "done" { + return fmt.Errorf("systemd service restart failed with result %q", res) + } + case <-ctx.Done(): + return ctx.Err() + } + return nil +} diff --git a/vendor/tailscale.com/clientupdate/systemd_other.go b/vendor/tailscale.com/clientupdate/systemd_other.go new file mode 100644 index 0000000000..b2412b6e44 --- /dev/null +++ b/vendor/tailscale.com/clientupdate/systemd_other.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package clientupdate + +import ( + "context" + "errors" +) + +func restartSystemdUnit(ctx context.Context) error { + return errors.ErrUnsupported +} diff --git a/vendor/tailscale.com/control/controlclient/auto.go b/vendor/tailscale.com/control/controlclient/auto.go index dab37a5100..fa5e2e1065 100644 --- a/vendor/tailscale.com/control/controlclient/auto.go +++ b/vendor/tailscale.com/control/controlclient/auto.go @@ -9,13 +9,14 @@ import ( "fmt" "net/http" "sync" + "sync/atomic" "time" "tailscale.com/health" "tailscale.com/logtail/backoff" "tailscale.com/net/sockstats" "tailscale.com/tailcfg" - "tailscale.com/types/empty" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" @@ -24,60 +25,127 @@ import ( ) type LoginGoal struct { - _ structs.Incomparable - wantLoggedIn bool // true if we *want* to be logged in - token *tailcfg.Oauth2Token // oauth token to use when logging in - flags LoginFlags // flags to use when logging in - url string // auth url that needs to be visited - loggedOutResult chan<- error + _ 
structs.Incomparable + token *tailcfg.Oauth2Token // oauth token to use when logging in + flags LoginFlags // flags to use when logging in + url string // auth url that needs to be visited } -func (g *LoginGoal) sendLogoutError(err error) { - if g.loggedOutResult == nil { - return +var _ Client = (*Auto)(nil) + +// waitUnpause waits until either the client is unpaused or the Auto client is +// shut down. It reports whether the client should keep running (i.e. it's not +// closed). +func (c *Auto) waitUnpause(routineLogName string) (keepRunning bool) { + c.mu.Lock() + if !c.paused || c.closed { + defer c.mu.Unlock() + return !c.closed } - select { - case g.loggedOutResult <- err: - default: + unpaused := c.unpausedChanLocked() + c.mu.Unlock() + + c.logf("%s: awaiting unpause", routineLogName) + return <-unpaused +} + +// updateRoutine is responsible for informing the server of worthy changes to +// our local state. It runs in its own goroutine. +func (c *Auto) updateRoutine() { + defer close(c.updateDone) + bo := backoff.NewBackoff("updateRoutine", c.logf, 30*time.Second) + + // lastUpdateGenInformed is the value of lastUpdateAt that we've successfully + // informed the server of. + var lastUpdateGenInformed updateGen + + for { + if !c.waitUnpause("updateRoutine") { + c.logf("updateRoutine: exiting") + return + } + c.mu.Lock() + gen := c.lastUpdateGen + ctx := c.mapCtx + needUpdate := gen > 0 && gen != lastUpdateGenInformed && c.loggedIn + c.mu.Unlock() + + if !needUpdate { + // Nothing to do, wait for a signal. 
+ select { + case <-ctx.Done(): + continue + case <-c.updateCh: + continue + } + } + + t0 := c.clock.Now() + err := c.direct.SendUpdate(ctx) + d := time.Since(t0).Round(time.Millisecond) + if err != nil { + if ctx.Err() == nil { + c.direct.logf("lite map update error after %v: %v", d, err) + } + bo.BackOff(ctx, err) + continue + } + bo.BackOff(ctx, nil) + c.direct.logf("[v1] successful lite map update in %v", d) + + lastUpdateGenInformed = gen } } -var _ Client = (*Auto)(nil) +// atomicGen is an atomic int64 generator. It is used to generate monotonically +// increasing numbers for updateGen. +var atomicGen atomic.Int64 + +func nextUpdateGen() updateGen { + return updateGen(atomicGen.Add(1)) +} + +// updateGen is a monotonically increasing number that represents a particular +// update to the local state. +type updateGen int64 // Auto connects to a tailcontrol server for a node. // It's a concrete implementation of the Client interface. type Auto struct { - direct *Direct // our interface to the server APIs - timeNow func() time.Time - logf logger.Logf - expiry *time.Time - closed bool - newMapCh chan struct{} // readable when we must restart a map request - statusFunc func(Status) // called to update Client status; always non-nil + direct *Direct // our interface to the server APIs + clock tstime.Clock + logf logger.Logf + closed bool + updateCh chan struct{} // readable when we should inform the server of a change + observer Observer // called to update Client status; always non-nil + observerQueue execQueue unregisterHealthWatch func() mu sync.Mutex // mutex guards the following fields - paused bool // whether we should stop making HTTP requests - unpauseWaiters []chan struct{} - loggedIn bool // true if currently logged in - loginGoal *LoginGoal // non-nil if some login activity is desired - synced bool // true if our netmap is up-to-date - inPollNetMap bool // true if currently running a PollNetMap - inLiteMapUpdate bool // true if a lite (non-streaming) map 
request is outstanding - liteMapUpdateCancel context.CancelFunc // cancels a lite map update, may be nil - liteMapUpdateCancels int // how many times we've canceled a lite map update - inSendStatus int // number of sendStatus calls currently in progress - state State + wantLoggedIn bool // whether the user wants to be logged in per last method call + urlToVisit string // the last url we were told to visit + expiry time.Time + + // lastUpdateGen is the gen of last update we had an update worth sending to + // the server. + lastUpdateGen updateGen + + paused bool // whether we should stop making HTTP requests + unpauseWaiters []chan bool // chans that gets sent true (once) on wake, or false on Shutdown + loggedIn bool // true if currently logged in + loginGoal *LoginGoal // non-nil if some login activity is desired + inMapPoll bool // true once we get the first MapResponse in a stream; false when HTTP response ends + state State // TODO(bradfitz): delete this, make it computed by method from other state authCtx context.Context // context used for auth requests - mapCtx context.Context // context used for netmap requests - authCancel func() // cancel the auth context - mapCancel func() // cancel the netmap context - quit chan struct{} // when closed, goroutines should all exit - authDone chan struct{} // when closed, auth goroutine is done - mapDone chan struct{} // when closed, map goroutine is done + mapCtx context.Context // context used for netmap and update requests + authCancel func() // cancel authCtx + mapCancel func() // cancel mapCtx + authDone chan struct{} // when closed, authRoutine is done + mapDone chan struct{} // when closed, mapRoutine is done + updateDone chan struct{} // when closed, updateRoutine is done } // New creates and starts a new Auto. 
@@ -101,24 +169,24 @@ func NewNoStart(opts Options) (_ *Auto, err error) { } }() - if opts.Status == nil { - return nil, errors.New("missing required Options.Status") + if opts.Observer == nil { + return nil, errors.New("missing required Options.Observer") } if opts.Logf == nil { opts.Logf = func(fmt string, args ...any) {} } - if opts.TimeNow == nil { - opts.TimeNow = time.Now + if opts.Clock == nil { + opts.Clock = tstime.StdClock{} } c := &Auto{ direct: direct, - timeNow: opts.TimeNow, + clock: opts.Clock, logf: opts.Logf, - newMapCh: make(chan struct{}, 1), - quit: make(chan struct{}), + updateCh: make(chan struct{}, 1), authDone: make(chan struct{}), mapDone: make(chan struct{}), - statusFunc: opts.Status, + updateDone: make(chan struct{}), + observer: opts.Observer, } c.authCtx, c.authCancel = context.WithCancel(context.Background()) c.authCtx = sockstats.WithSockStats(c.authCtx, sockstats.LabelControlClientAuto, opts.Logf) @@ -137,21 +205,20 @@ func NewNoStart(opts Options) (_ *Auto, err error) { func (c *Auto) SetPaused(paused bool) { c.mu.Lock() defer c.mu.Unlock() - if paused == c.paused { + if paused == c.paused || c.closed { return } c.logf("setPaused(%v)", paused) c.paused = paused if paused { - // Only cancel the map routine. (The auth routine isn't expensive - // so it's fine to keep it running.) - c.cancelMapLocked() - } else { - for _, ch := range c.unpauseWaiters { - close(ch) - } - c.unpauseWaiters = nil + c.cancelMapCtxLocked() + c.cancelAuthCtxLocked() + return + } + for _, ch := range c.unpauseWaiters { + ch <- true } + c.unpauseWaiters = nil } // Start starts the client's goroutines. 
@@ -160,85 +227,41 @@ func (c *Auto) SetPaused(paused bool) { func (c *Auto) Start() { go c.authRoutine() go c.mapRoutine() + go c.updateRoutine() } -// sendNewMapRequest either sends a new OmitPeers, non-streaming map request -// (to just send Hostinfo/Netinfo/Endpoints info, while keeping an existing -// streaming response open), or start a new streaming one if necessary. +// updateControl sends a new OmitPeers, non-streaming map request (to just send +// Hostinfo/Netinfo/Endpoints info, while keeping an existing streaming response +// open). // // It should be called whenever there's something new to tell the server. -func (c *Auto) sendNewMapRequest() { +func (c *Auto) updateControl() { + gen := nextUpdateGen() c.mu.Lock() - - // If we're not already streaming a netmap, then tear down everything - // and start a new stream (which starts by sending a new map request) - if !c.inPollNetMap || !c.loggedIn { + if gen < c.lastUpdateGen { + // This update is out of date. c.mu.Unlock() - c.cancelMapSafely() return } + c.lastUpdateGen = gen + c.mu.Unlock() - // If we are already in process of doing a LiteMapUpdate, cancel it and - // try a new one. If this is the 10th time we have done this - // cancelation, tear down everything and start again. - const maxLiteMapUpdateAttempts = 10 - if c.inLiteMapUpdate { - // Always cancel the in-flight lite map update, regardless of - // whether we cancel the streaming map request or not. - c.liteMapUpdateCancel() - c.inLiteMapUpdate = false - - if c.liteMapUpdateCancels >= maxLiteMapUpdateAttempts { - // Not making progress - c.mu.Unlock() - c.cancelMapSafely() - return - } - - // Increment our cancel counter and continue below to start a - // new lite update. - c.liteMapUpdateCancels++ + select { + case c.updateCh <- struct{}{}: + default: } +} - // Otherwise, send a lite update that doesn't keep a - // long-running stream response. 
+// cancelAuthCtx cancels the existing auth goroutine's context +// & creates a new one, causing it to restart. +func (c *Auto) cancelAuthCtx() { + c.mu.Lock() defer c.mu.Unlock() - c.inLiteMapUpdate = true - ctx, cancel := context.WithTimeout(c.mapCtx, 10*time.Second) - c.liteMapUpdateCancel = cancel - go func() { - defer cancel() - t0 := time.Now() - err := c.direct.SendLiteMapUpdate(ctx) - d := time.Since(t0).Round(time.Millisecond) - - c.mu.Lock() - c.inLiteMapUpdate = false - c.liteMapUpdateCancel = nil - if err == nil { - c.liteMapUpdateCancels = 0 - } - c.mu.Unlock() - - if err == nil { - c.logf("[v1] successful lite map update in %v", d) - return - } - if ctx.Err() == nil { - c.logf("lite map update after %v: %v", d, err) - } - if !errors.Is(ctx.Err(), context.Canceled) { - // Fall back to restarting the long-polling map - // request (the old heavy way) if the lite update - // failed for reasons other than the context being - // canceled. - c.cancelMapSafely() - } - }() + c.cancelAuthCtxLocked() } -func (c *Auto) cancelAuth() { - c.mu.Lock() +// cancelAuthCtxLocked is like cancelAuthCtx, but assumes the caller holds c.mu. +func (c *Auto) cancelAuthCtxLocked() { if c.authCancel != nil { c.authCancel() } @@ -246,66 +269,37 @@ func (c *Auto) cancelAuth() { c.authCtx, c.authCancel = context.WithCancel(context.Background()) c.authCtx = sockstats.WithSockStats(c.authCtx, sockstats.LabelControlClientAuto, c.logf) } - c.mu.Unlock() } -func (c *Auto) cancelMapLocked() { +// cancelMapCtx cancels the context for the existing mapPoll and liteUpdates +// goroutines and creates a new one, causing them to restart. +func (c *Auto) cancelMapCtx() { + c.mu.Lock() + defer c.mu.Unlock() + c.cancelMapCtxLocked() +} + +// cancelMapCtxLocked is like cancelMapCtx, but assumes the caller holds c.mu. 
+func (c *Auto) cancelMapCtxLocked() { if c.mapCancel != nil { c.mapCancel() } if !c.closed { c.mapCtx, c.mapCancel = context.WithCancel(context.Background()) c.mapCtx = sockstats.WithSockStats(c.mapCtx, sockstats.LabelControlClientAuto, c.logf) - } } -func (c *Auto) cancelMapUnsafely() { +// restartMap cancels the existing mapPoll and liteUpdates, and then starts a +// new one. +func (c *Auto) restartMap() { c.mu.Lock() - c.cancelMapLocked() + c.cancelMapCtxLocked() + synced := c.inMapPoll c.mu.Unlock() -} - -func (c *Auto) cancelMapSafely() { - c.mu.Lock() - defer c.mu.Unlock() - // Always reset our lite map cancels counter if we're canceling - // everything, since we're about to restart with a new map update; this - // allows future calls to sendNewMapRequest to retry sending lite - // updates. - c.liteMapUpdateCancels = 0 - - c.logf("[v1] cancelMapSafely: synced=%v", c.synced) - - if c.inPollNetMap { - // received at least one netmap since the last - // interruption. That means the server has already - // fully processed our last request, which might - // include UpdateEndpoints(). Interrupt it and try - // again. - c.cancelMapLocked() - } else { - // !synced means we either haven't done a netmap - // request yet, or it hasn't answered yet. So the - // server is in an undefined state. If we send - // another netmap request too soon, it might race - // with the last one, and if we're very unlucky, - // the new request will be applied before the old one, - // and the wrong endpoints will get registered. We - // have to tell the client to abort politely, only - // after it receives a response to its existing netmap - // request. - select { - case c.newMapCh <- struct{}{}: - c.logf("[v1] cancelMapSafely: wrote to channel") - default: - // if channel write failed, then there was already - // an outstanding newMapCh request. One is enough, - // since it'll always use the latest endpoints. 
- c.logf("[v1] cancelMapSafely: channel was full") - } - } + c.logf("[v1] restartMap: synced=%v", synced) + c.updateControl() } func (c *Auto) authRoutine() { @@ -313,23 +307,20 @@ func (c *Auto) authRoutine() { bo := backoff.NewBackoff("authRoutine", c.logf, 30*time.Second) for { + if !c.waitUnpause("authRoutine") { + c.logf("authRoutine: exiting") + return + } c.mu.Lock() goal := c.loginGoal ctx := c.authCtx if goal != nil { - c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, goal.wantLoggedIn) + c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, true) } else { c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused) } c.mu.Unlock() - select { - case <-c.quit: - c.logf("[v1] authRoutine: quit") - return - default: - } - report := func(err error, msg string) { c.logf("[v1] %s: %v", msg, err) // don't send status updates for context errors, @@ -347,146 +338,166 @@ func (c *Auto) authRoutine() { continue } - if !goal.wantLoggedIn { - health.SetAuthRoutineInError(nil) - err := c.direct.TryLogout(ctx) - goal.sendLogoutError(err) - if err != nil { - report(err, "TryLogout") - bo.BackOff(ctx, err) - continue - } - - // success - c.mu.Lock() - c.loggedIn = false - c.loginGoal = nil - c.state = StateNotAuthenticated - c.synced = false - c.mu.Unlock() + c.mu.Lock() + c.urlToVisit = goal.url + if goal.url != "" { + c.state = StateURLVisitRequired + } else { + c.state = StateAuthenticating + } + c.mu.Unlock() - c.sendStatus("authRoutine-wantout", nil, "", nil) - bo.BackOff(ctx, nil) - } else { // ie. goal.wantLoggedIn + var url string + var err error + var f string + if goal.url != "" { + url, err = c.direct.WaitLoginURL(ctx, goal.url) + f = "WaitLoginURL" + } else { + url, err = c.direct.TryLogin(ctx, goal.token, goal.flags) + f = "TryLogin" + } + if err != nil { + health.SetAuthRoutineInError(err) + report(err, f) + bo.BackOff(ctx, err) + continue + } + if url != "" { + // goal.url ought to be empty here. 
+ // However, not all control servers get this right, + // and logging about it here just generates noise. c.mu.Lock() - if goal.url != "" { - c.state = StateURLVisitRequired - } else { - c.state = StateAuthenticating + c.urlToVisit = url + c.loginGoal = &LoginGoal{ + flags: LoginDefault, + url: url, } + c.state = StateURLVisitRequired c.mu.Unlock() - var url string - var err error - var f string - if goal.url != "" { - url, err = c.direct.WaitLoginURL(ctx, goal.url) - f = "WaitLoginURL" + c.sendStatus("authRoutine-url", err, url, nil) + if goal.url == url { + // The server sent us the same URL we already tried, + // backoff to avoid a busy loop. + bo.BackOff(ctx, errors.New("login URL not changing")) } else { - url, err = c.direct.TryLogin(ctx, goal.token, goal.flags) - f = "TryLogin" - } - if err != nil { - health.SetAuthRoutineInError(err) - report(err, f) - bo.BackOff(ctx, err) - continue - } - if url != "" { - // goal.url ought to be empty here. - // However, not all control servers get this right, - // and logging about it here just generates noise. - c.mu.Lock() - c.loginGoal = &LoginGoal{ - wantLoggedIn: true, - flags: LoginDefault, - url: url, - } - c.state = StateURLVisitRequired - c.synced = false - c.mu.Unlock() - - c.sendStatus("authRoutine-url", err, url, nil) - if goal.url == url { - // The server sent us the same URL we already tried, - // backoff to avoid a busy loop. 
- bo.BackOff(ctx, errors.New("login URL not changing")) - } else { - bo.BackOff(ctx, nil) - } - continue + bo.BackOff(ctx, nil) } + continue + } - // success - health.SetAuthRoutineInError(nil) - c.mu.Lock() - c.loggedIn = true - c.loginGoal = nil - c.state = StateAuthenticated - c.mu.Unlock() + // success + health.SetAuthRoutineInError(nil) + c.mu.Lock() + c.urlToVisit = "" + c.loggedIn = true + c.loginGoal = nil + c.state = StateAuthenticated + c.mu.Unlock() - c.sendStatus("authRoutine-success", nil, "", nil) - c.cancelMapSafely() - bo.BackOff(ctx, nil) - } + c.sendStatus("authRoutine-success", nil, "", nil) + c.restartMap() + bo.BackOff(ctx, nil) } } -// Expiry returns the credential expiration time, or the zero time if -// the expiration time isn't known. Used in tests only. -func (c *Auto) Expiry() *time.Time { +// ExpiryForTests returns the credential expiration time, or the zero value if +// the expiration time isn't known. It's used in tests only. +func (c *Auto) ExpiryForTests() time.Time { c.mu.Lock() defer c.mu.Unlock() return c.expiry } -// Direct returns the underlying direct client object. Used in tests -// only. -func (c *Auto) Direct() *Direct { +// DirectForTest returns the underlying direct client object. +// It's used in tests only. +func (c *Auto) DirectForTest() *Direct { return c.direct } -// unpausedChanLocked returns a new channel that is closed when the -// current Auto pause is unpaused. +// unpausedChanLocked returns a new channel that gets sent +// either a true when unpaused or false on Auto.Shutdown. // // c.mu must be held -func (c *Auto) unpausedChanLocked() <-chan struct{} { - unpaused := make(chan struct{}) +func (c *Auto) unpausedChanLocked() <-chan bool { + unpaused := make(chan bool, 1) c.unpauseWaiters = append(c.unpauseWaiters, unpaused) return unpaused } +// mapRoutineState is the state of Auto.mapRoutine while it's running. 
+type mapRoutineState struct { + c *Auto + bo *backoff.Backoff +} + +var _ NetmapDeltaUpdater = mapRoutineState{} + +func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) { + c := mrs.c + + c.mu.Lock() + ctx := c.mapCtx + c.inMapPoll = true + if c.loggedIn { + c.state = StateSynchronized + } + c.expiry = nm.Expiry + stillAuthed := c.loggedIn + c.logf("[v1] mapRoutine: netmap received: %s", c.state) + c.mu.Unlock() + + if stillAuthed { + c.sendStatus("mapRoutine-got-netmap", nil, "", nm) + } + // Reset the backoff timer if we got a netmap. + mrs.bo.BackOff(ctx, nil) +} + +func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool { + c := mrs.c + + c.mu.Lock() + goodState := c.loggedIn && c.inMapPoll + ndu, canDelta := c.observer.(NetmapDeltaUpdater) + c.mu.Unlock() + + if !goodState || !canDelta { + return false + } + + ctx, cancel := context.WithTimeout(c.mapCtx, 2*time.Second) + defer cancel() + + var ok bool + err := c.observerQueue.RunSync(ctx, func() { + ok = ndu.UpdateNetmapDelta(muts) + }) + return err == nil && ok +} + +// mapRoutine is responsible for keeping a read-only streaming connection to the +// control server, and keeping the netmap up to date. 
func (c *Auto) mapRoutine() { defer close(c.mapDone) - bo := backoff.NewBackoff("mapRoutine", c.logf, 30*time.Second) + mrs := &mapRoutineState{ + c: c, + bo: backoff.NewBackoff("mapRoutine", c.logf, 30*time.Second), + } for { - c.mu.Lock() - if c.paused { - unpaused := c.unpausedChanLocked() - c.mu.Unlock() - c.logf("mapRoutine: awaiting unpause") - select { - case <-unpaused: - c.logf("mapRoutine: unpaused") - case <-c.quit: - c.logf("mapRoutine: quit") - return - } - continue + if !c.waitUnpause("mapRoutine") { + c.logf("mapRoutine: exiting") + return } + + c.mu.Lock() c.logf("[v1] mapRoutine: %s", c.state) loggedIn := c.loggedIn ctx := c.mapCtx c.mu.Unlock() - select { - case <-c.quit: - c.logf("mapRoutine: quit") - return - default: - } - report := func(err error, msg string) { c.logf("[v1] %s: %v", msg, err) err = fmt.Errorf("%s: %w", msg, err) @@ -500,80 +511,32 @@ func (c *Auto) mapRoutine() { if !loggedIn { // Wait for something interesting to happen c.mu.Lock() - c.synced = false - // c.state is set by authRoutine() + c.inMapPoll = false c.mu.Unlock() - select { - case <-ctx.Done(): - c.logf("[v1] mapRoutine: context done.") - case <-c.newMapCh: - c.logf("[v1] mapRoutine: new map needed while idle.") - } - } else { - // Be sure this is false when we're not inside - // PollNetMap, so that cancelMapSafely() can notify - // us correctly. - c.mu.Lock() - c.inPollNetMap = false - c.mu.Unlock() - health.SetInPollNetMap(false) - - err := c.direct.PollNetMap(ctx, func(nm *netmap.NetworkMap) { - health.SetInPollNetMap(true) - c.mu.Lock() - - select { - case <-c.newMapCh: - c.logf("[v1] mapRoutine: new map request during PollNetMap. canceling.") - c.cancelMapLocked() - - // Don't emit this netmap; we're - // about to request a fresh one. 
- c.mu.Unlock() - return - default: - } - - c.synced = true - c.inPollNetMap = true - if c.loggedIn { - c.state = StateSynchronized - } - exp := nm.Expiry - c.expiry = &exp - stillAuthed := c.loggedIn - state := c.state - - c.mu.Unlock() - - c.logf("[v1] mapRoutine: netmap received: %s", state) - if stillAuthed { - c.sendStatus("mapRoutine-got-netmap", nil, "", nm) - } - }) - - health.SetInPollNetMap(false) - c.mu.Lock() - c.synced = false - c.inPollNetMap = false - if c.state == StateSynchronized { - c.state = StateAuthenticated - } - paused := c.paused - c.mu.Unlock() + <-ctx.Done() + c.logf("[v1] mapRoutine: context done.") + continue + } + health.SetOutOfPollNetMap() - if paused { - c.logf("mapRoutine: paused") - continue - } + err := c.direct.PollNetMap(ctx, mrs) - if err != nil { - report(err, "PollNetMap") - bo.BackOff(ctx, err) - continue - } - bo.BackOff(ctx, nil) + health.SetOutOfPollNetMap() + c.mu.Lock() + c.inMapPoll = false + if c.state == StateSynchronized { + c.state = StateAuthenticated + } + paused := c.paused + c.mu.Unlock() + + if paused { + mrs.bo.BackOff(ctx, nil) + c.logf("mapRoutine: paused") + } else { + mrs.bo.BackOff(ctx, err) + report(err, "PollNetMap") } } } @@ -598,7 +561,7 @@ func (c *Auto) SetHostinfo(hi *tailcfg.Hostinfo) { } // Send new Hostinfo to server - c.sendNewMapRequest() + c.updateControl() } func (c *Auto) SetNetInfo(ni *tailcfg.NetInfo) { @@ -610,14 +573,20 @@ func (c *Auto) SetNetInfo(ni *tailcfg.NetInfo) { } // Send new NetInfo to server - c.sendNewMapRequest() + c.updateControl() } // SetTKAHead updates the TKA head hash that map-request infrastructure sends. func (c *Auto) SetTKAHead(headHash string) { - c.direct.SetTKAHead(headHash) + if !c.direct.SetTKAHead(headHash) { + return + } + + // Send new TKAHead to server + c.updateControl() } +// sendStatus can not be called with the c.mu held. 
func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkMap) { c.mu.Lock() if c.closed { @@ -626,92 +595,77 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM } state := c.state loggedIn := c.loggedIn - synced := c.synced - c.inSendStatus++ + inMapPoll := c.inMapPoll c.mu.Unlock() c.logf("[v1] sendStatus: %s: %v", who, state) - var p *persist.PersistView - var loginFin, logoutFin *empty.Message - if state == StateAuthenticated { - loginFin = new(empty.Message) - } - if state == StateNotAuthenticated { - logoutFin = new(empty.Message) - } - if nm != nil && loggedIn && synced { - pp := c.direct.GetPersist() - p = &pp + var p persist.PersistView + if nm != nil && loggedIn && inMapPoll { + p = c.direct.GetPersist() } else { // don't send netmap status, as it's misleading when we're // not logged in. nm = nil } new := Status{ - LoginFinished: loginFin, - LogoutFinished: logoutFin, - URL: url, - Persist: p, - NetMap: nm, - State: state, - Err: err, + URL: url, + Persist: p, + NetMap: nm, + Err: err, + state: state, } - c.statusFunc(new) - c.mu.Lock() - c.inSendStatus-- - c.mu.Unlock() + // Launch a new goroutine to avoid blocking the caller while the observer + // does its thing, which may result in a call back into the client. 
+ c.observerQueue.Add(func() { + c.observer.SetControlClientStatus(c, new) + }) } func (c *Auto) Login(t *tailcfg.Oauth2Token, flags LoginFlags) { c.logf("client.Login(%v, %v)", t != nil, flags) c.mu.Lock() - c.loginGoal = &LoginGoal{ - wantLoggedIn: true, - token: t, - flags: flags, + defer c.mu.Unlock() + if c.closed { + return } - c.mu.Unlock() - - c.cancelAuth() -} - -func (c *Auto) StartLogout() { - c.logf("client.StartLogout()") - - c.mu.Lock() + c.wantLoggedIn = true c.loginGoal = &LoginGoal{ - wantLoggedIn: false, + token: t, + flags: flags, } - c.mu.Unlock() - c.cancelAuth() + c.cancelMapCtxLocked() + c.cancelAuthCtxLocked() } +var ErrClientClosed = errors.New("client closed") + func (c *Auto) Logout(ctx context.Context) error { c.logf("client.Logout()") - - errc := make(chan error, 1) - c.mu.Lock() - c.loginGoal = &LoginGoal{ - wantLoggedIn: false, - loggedOutResult: errc, - } + c.wantLoggedIn = false + c.loginGoal = nil + closed := c.closed c.mu.Unlock() - c.cancelAuth() - timer := time.NewTimer(10 * time.Second) - defer timer.Stop() - select { - case err := <-errc: + if closed { + return ErrClientClosed + } + + if err := c.direct.TryLogout(ctx); err != nil { return err - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return context.DeadlineExceeded } + c.mu.Lock() + c.loggedIn = false + c.state = StateNotAuthenticated + c.cancelAuthCtxLocked() + c.cancelMapCtxLocked() + c.mu.Unlock() + + c.sendStatus("authRoutine-wantout", nil, "", nil) + return nil } func (c *Auto) SetExpirySooner(ctx context.Context, expiry time.Time) error { @@ -725,7 +679,7 @@ func (c *Auto) SetExpirySooner(ctx context.Context, expiry time.Time) error { func (c *Auto) UpdateEndpoints(endpoints []tailcfg.Endpoint) { changed := c.direct.SetEndpoints(endpoints) if changed { - c.sendNewMapRequest() + c.updateControl() } } @@ -733,25 +687,32 @@ func (c *Auto) Shutdown() { c.logf("client.Shutdown()") c.mu.Lock() - inSendStatus := c.inSendStatus closed := c.closed direct := 
c.direct if !closed { c.closed = true + c.observerQueue.shutdown() + c.cancelAuthCtxLocked() + c.cancelMapCtxLocked() + for _, w := range c.unpauseWaiters { + w <- false + } + c.unpauseWaiters = nil } c.mu.Unlock() - c.logf("client.Shutdown: inSendStatus=%v", inSendStatus) + c.logf("client.Shutdown") if !closed { c.unregisterHealthWatch() - close(c.quit) - c.cancelAuth() <-c.authDone - c.cancelMapUnsafely() <-c.mapDone + <-c.updateDone if direct != nil { direct.Close() } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + c.observerQueue.wait(ctx) c.logf("Client.Shutdown done.") } } @@ -770,7 +731,7 @@ func (c *Auto) TestOnlySetAuthKey(authkey string) { } func (c *Auto) TestOnlyTimeNow() time.Time { - return c.timeNow() + return c.clock.Now() } // SetDNS sends the SetDNSRequest request to the control plane server, @@ -792,3 +753,95 @@ func (c *Auto) DoNoiseRequest(req *http.Request) (*http.Response, error) { func (c *Auto) GetSingleUseNoiseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) { return c.direct.GetSingleUseNoiseRoundTripper(ctx) } + +type execQueue struct { + mu sync.Mutex + closed bool + inFlight bool // whether a goroutine is running q.run + doneWaiter chan struct{} // non-nil if waiter is waiting, then closed + queue []func() +} + +func (q *execQueue) Add(f func()) { + q.mu.Lock() + defer q.mu.Unlock() + if q.closed { + return + } + if q.inFlight { + q.queue = append(q.queue, f) + } else { + q.inFlight = true + go q.run(f) + } +} + +// RunSync waits for the queue to be drained and then synchronously runs f. +// It returns an error if the queue is closed before f is run or ctx expires. 
+func (q *execQueue) RunSync(ctx context.Context, f func()) error { + for { + if err := q.wait(ctx); err != nil { + return err + } + q.mu.Lock() + if q.inFlight { + q.mu.Unlock() + continue + } + defer q.mu.Unlock() + if q.closed { + return errors.New("closed") + } + f() + return nil + } +} + +func (q *execQueue) run(f func()) { + f() + + q.mu.Lock() + for len(q.queue) > 0 && !q.closed { + f := q.queue[0] + q.queue[0] = nil + q.queue = q.queue[1:] + q.mu.Unlock() + f() + q.mu.Lock() + } + q.inFlight = false + q.queue = nil + if q.doneWaiter != nil { + close(q.doneWaiter) + q.doneWaiter = nil + } + q.mu.Unlock() +} + +func (q *execQueue) shutdown() { + q.mu.Lock() + defer q.mu.Unlock() + q.closed = true +} + +// wait waits for the queue to be empty. +func (q *execQueue) wait(ctx context.Context) error { + q.mu.Lock() + waitCh := q.doneWaiter + if q.inFlight && waitCh == nil { + waitCh = make(chan struct{}) + q.doneWaiter = waitCh + } + q.mu.Unlock() + + if waitCh == nil { + return nil + } + + select { + case <-waitCh: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/vendor/tailscale.com/control/controlclient/client.go b/vendor/tailscale.com/control/controlclient/client.go index 80e56790ee..b809f81927 100644 --- a/vendor/tailscale.com/control/controlclient/client.go +++ b/vendor/tailscale.com/control/controlclient/client.go @@ -25,6 +25,9 @@ const ( // Client represents a client connection to the control server. // Currently this is done through a pair of polling https requests in // the Auto client, but that might change eventually. +// +// The Client must be comparable as it is used by the Observer to detect stale +// clients. type Client interface { // Shutdown closes this session, which should not be used any further // afterwards. @@ -34,10 +37,6 @@ type Client interface { // LoginFinished flag (on success) or an auth URL (if further // interaction is needed). 
Login(*tailcfg.Oauth2Token, LoginFlags) - // StartLogout starts an asynchronous logout process. - // When it finishes, the Status callback will be called while - // AuthCantContinue()==true. - StartLogout() // Logout starts a synchronous logout process. It doesn't return // until the logout operation has been completed. Logout(context.Context) error diff --git a/vendor/tailscale.com/control/controlclient/debug.go b/vendor/tailscale.com/control/controlclient/debug.go deleted file mode 100644 index 288abfab4c..0000000000 --- a/vendor/tailscale.com/control/controlclient/debug.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "bytes" - "compress/gzip" - "context" - "log" - "net/http" - "time" - - "tailscale.com/util/goroutines" -) - -func dumpGoroutinesToURL(c *http.Client, targetURL string) { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - - zbuf := new(bytes.Buffer) - zw := gzip.NewWriter(zbuf) - zw.Write(goroutines.ScrubbedGoroutineDump()) - zw.Close() - - req, err := http.NewRequestWithContext(ctx, "PUT", targetURL, zbuf) - if err != nil { - log.Printf("dumpGoroutinesToURL: %v", err) - return - } - req.Header.Set("Content-Encoding", "gzip") - t0 := time.Now() - _, err = c.Do(req) - d := time.Since(t0).Round(time.Millisecond) - if err != nil { - log.Printf("dumpGoroutinesToURL error: %v to %v (after %v)", err, targetURL, d) - } else { - log.Printf("dumpGoroutinesToURL complete to %v (after %v)", targetURL, d) - } -} diff --git a/vendor/tailscale.com/control/controlclient/direct.go b/vendor/tailscale.com/control/controlclient/direct.go index 006f2614a3..a19c030b4a 100644 --- a/vendor/tailscale.com/control/controlclient/direct.go +++ b/vendor/tailscale.com/control/controlclient/direct.go @@ -22,6 +22,7 @@ import ( "os" "reflect" "runtime" + "slices" "strings" "sync" "time" @@ -32,7 +33,6 @@ import ( 
"tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn/ipnstate" - "tailscale.com/log/logheap" "tailscale.com/logtail" "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" @@ -42,14 +42,15 @@ import ( "tailscale.com/net/tlsdial" "tailscale.com/net/tsdial" "tailscale.com/net/tshttpproxy" - "tailscale.com/syncs" + "tailscale.com/smallzstd" "tailscale.com/tailcfg" "tailscale.com/tka" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" - "tailscale.com/types/opt" "tailscale.com/types/persist" + "tailscale.com/types/ptr" "tailscale.com/types/tkatype" "tailscale.com/util/clientmetric" "tailscale.com/util/multierr" @@ -59,26 +60,24 @@ import ( // Direct is the client that connects to a tailcontrol server for a node. type Direct struct { - httpc *http.Client // HTTP client used to talk to tailcontrol - dialer *tsdial.Dialer - dnsCache *dnscache.Resolver - serverURL string // URL of the tailcontrol server - timeNow func() time.Time - lastPrintMap time.Time - newDecompressor func() (Decompressor, error) - keepAlive bool - logf logger.Logf - netMon *netmon.Monitor // or nil - discoPubKey key.DiscoPublic - getMachinePrivKey func() (key.MachinePrivate, error) - debugFlags []string - keepSharerAndUserSplit bool - skipIPForwardingCheck bool - pinger Pinger - popBrowser func(url string) // or nil - c2nHandler http.Handler // or nil - onClientVersion func(*tailcfg.ClientVersion) // or nil - onControlTime func(time.Time) // or nil + httpc *http.Client // HTTP client used to talk to tailcontrol + dialer *tsdial.Dialer + dnsCache *dnscache.Resolver + controlKnobs *controlknobs.Knobs // always non-nil + serverURL string // URL of the tailcontrol server + clock tstime.Clock + lastPrintMap time.Time + logf logger.Logf + netMon *netmon.Monitor // or nil + discoPubKey key.DiscoPublic + getMachinePrivKey func() (key.MachinePrivate, error) + debugFlags []string + skipIPForwardingCheck bool + pinger Pinger + 
popBrowser func(url string) // or nil + c2nHandler http.Handler // or nil + onClientVersion func(*tailcfg.ClientVersion) // or nil + onControlTime func(time.Time) // or nil dialPlan ControlDialPlanner // can be nil @@ -92,7 +91,7 @@ type Direct struct { persist persist.PersistView authKey string tryingNewKey key.NodePrivate - expiry *time.Time + expiry time.Time // or zero value if none/unknown hostinfo *tailcfg.Hostinfo // always non-nil netinfo *tailcfg.NetInfo endpoints []tailcfg.Endpoint @@ -100,16 +99,24 @@ type Direct struct { lastPingURL string // last PingRequest.URL received, for dup suppression } +// Observer is implemented by users of the control client (such as LocalBackend) +// to get notified of changes in the control client's status. +type Observer interface { + // SetControlClientStatus is called when the client has a new status to + // report. The Client is provided to allow the Observer to track which + // Client is reporting the status, allowing it to ignore stale status + // reports from previous Clients. 
+ SetControlClientStatus(Client, Status) +} + type Options struct { Persist persist.Persist // initial persistent data GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use ServerURL string // URL of the tailcontrol server AuthKey string // optional node auth key for auto registration - TimeNow func() time.Time // time.Now implementation used by Client - Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc + Clock tstime.Clock + Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc DiscoPublicKey key.DiscoPublic - NewDecompressor func() (Decompressor, error) - KeepAlive bool Logf logger.Logf HTTPTestClient *http.Client // optional HTTP client to use (for tests only) NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only) @@ -120,13 +127,11 @@ type Options struct { OnControlTime func(time.Time) // optional func to notify callers of new time from control Dialer *tsdial.Dialer // non-nil C2NHandler http.Handler // or nil + ControlKnobs *controlknobs.Knobs // or nil to ignore - // Status is called when there's a change in status. - Status func(Status) - - // KeepSharerAndUserSplit controls whether the client - // understands Node.Sharer. If false, the Sharer is mapped to the User. - KeepSharerAndUserSplit bool + // Observer is called when there's a change in status to report + // from the control client. + Observer Observer // SkipIPForwardingCheck declares that the host's IP // forwarding works and should not be double-checked by the @@ -170,7 +175,7 @@ type ControlDialPlanner interface { // Pinger is the LocalBackend.Ping method. type Pinger interface { // Ping is a request to do a ping with the peer handling the given IP. 
- Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType) (*ipnstate.PingResult, error) + Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType, size int) (*ipnstate.PingResult, error) } type Decompressor interface { @@ -178,6 +183,29 @@ type Decompressor interface { Close() } +// NetmapUpdater is the interface needed by the controlclient to enact change in +// the world as a function of updates received from the network. +type NetmapUpdater interface { + UpdateFullNetmap(*netmap.NetworkMap) + + // TODO(bradfitz): add methods to do fine-grained updates, mutating just + // parts of peers, without implementations of NetmapUpdater needing to do + // the diff themselves between the previous full & next full network maps. +} + +// NetmapDeltaUpdater is an optional interface that can be implemented by +// NetmapUpdater implementations to receive delta updates from the controlclient +// rather than just full updates. +type NetmapDeltaUpdater interface { + // UpdateNetmapDelta is called with discrete changes to the network map. + // + // The ok result is whether the implementation was able to apply the + // mutations. It might return false if its internal state doesn't + // support applying them or a NetmapUpdater it's wrapping doesn't + // implement the NetmapDeltaUpdater optional method. + UpdateNetmapDelta([]netmap.NodeMutation) (ok bool) +} + // NewDirect returns a new Direct client. 
func NewDirect(opts Options) (*Direct, error) { if opts.ServerURL == "" { @@ -186,13 +214,16 @@ func NewDirect(opts Options) (*Direct, error) { if opts.GetMachinePrivateKey == nil { return nil, errors.New("controlclient.New: no GetMachinePrivateKey specified") } + if opts.ControlKnobs == nil { + opts.ControlKnobs = &controlknobs.Knobs{} + } opts.ServerURL = strings.TrimRight(opts.ServerURL, "/") serverURL, err := url.Parse(opts.ServerURL) if err != nil { return nil, err } - if opts.TimeNow == nil { - opts.TimeNow = time.Now + if opts.Clock == nil { + opts.Clock = tstime.StdClock{} } if opts.Logf == nil { // TODO(apenwarr): remove this default and fail instead. @@ -232,36 +263,32 @@ func NewDirect(opts Options) (*Direct, error) { } c := &Direct{ - httpc: httpc, - getMachinePrivKey: opts.GetMachinePrivateKey, - serverURL: opts.ServerURL, - timeNow: opts.TimeNow, - logf: opts.Logf, - newDecompressor: opts.NewDecompressor, - keepAlive: opts.KeepAlive, - persist: opts.Persist.View(), - authKey: opts.AuthKey, - discoPubKey: opts.DiscoPublicKey, - debugFlags: opts.DebugFlags, - keepSharerAndUserSplit: opts.KeepSharerAndUserSplit, - netMon: opts.NetMon, - skipIPForwardingCheck: opts.SkipIPForwardingCheck, - pinger: opts.Pinger, - popBrowser: opts.PopBrowserURL, - onClientVersion: opts.OnClientVersion, - onControlTime: opts.OnControlTime, - c2nHandler: opts.C2NHandler, - dialer: opts.Dialer, - dnsCache: dnsCache, - dialPlan: opts.DialPlan, + httpc: httpc, + controlKnobs: opts.ControlKnobs, + getMachinePrivKey: opts.GetMachinePrivateKey, + serverURL: opts.ServerURL, + clock: opts.Clock, + logf: opts.Logf, + persist: opts.Persist.View(), + authKey: opts.AuthKey, + discoPubKey: opts.DiscoPublicKey, + debugFlags: opts.DebugFlags, + netMon: opts.NetMon, + skipIPForwardingCheck: opts.SkipIPForwardingCheck, + pinger: opts.Pinger, + popBrowser: opts.PopBrowserURL, + onClientVersion: opts.OnClientVersion, + onControlTime: opts.OnControlTime, + c2nHandler: opts.C2NHandler, + dialer: 
opts.Dialer, + dnsCache: dnsCache, + dialPlan: opts.DialPlan, } if opts.Hostinfo == nil { c.SetHostinfo(hostinfo.New()) } else { - ni := opts.Hostinfo.NetInfo - opts.Hostinfo.NetInfo = nil c.SetHostinfo(opts.Hostinfo) - if ni != nil { + if ni := opts.Hostinfo.NetInfo; ni != nil { c.SetNetInfo(ni) } } @@ -293,6 +320,8 @@ func (c *Direct) SetHostinfo(hi *tailcfg.Hostinfo) bool { if hi == nil { panic("nil Hostinfo") } + hi = ptr.To(*hi) + hi.NetInfo = nil c.mu.Lock() defer c.mu.Unlock() @@ -432,7 +461,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new authKey, isWrapped, wrappedSig, wrappedKey := decodeWrappedAuthkey(c.authKey, c.logf) hi := c.hostInfoLocked() backendLogID := hi.BackendLogID - expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) + expired := !c.expiry.IsZero() && c.expiry.Before(c.clock.Now()) c.mu.Unlock() machinePrivKey, err := c.getMachinePrivKey() @@ -537,7 +566,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new err = errors.New("hostinfo: BackendLogID missing") return regen, opt.URL, nil, err } - now := time.Now().Round(time.Second) + now := c.clock.Now().Round(time.Second) request := tailcfg.RegisterRequest{ Version: 1, OldNodeKey: oldNodeKey, @@ -559,7 +588,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new request.NodeKey.ShortString(), opt.URL != "", len(nodeKeySignature) > 0) request.Auth.Oauth2Token = opt.Token request.Auth.Provider = persist.Provider - request.Auth.LoginName = persist.LoginName + request.Auth.LoginName = persist.UserProfile.LoginName request.Auth.AuthKey = authKey err = signRegisterRequest(&request, c.serverURL, c.serverKey, machinePrivKey.Public()) if err != nil { @@ -645,9 +674,6 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new if resp.Login.Provider != "" { persist.Provider = resp.Login.Provider } - if resp.Login.LoginName != "" { - persist.LoginName = 
resp.Login.LoginName - } persist.UserProfile = tailcfg.UserProfile{ ID: resp.User.ID, DisplayName: resp.Login.DisplayName, @@ -726,18 +752,6 @@ func resignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.Marsha return newSig.Serialize(), nil } -func sameEndpoints(a, b []tailcfg.Endpoint) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} - // newEndpoints acquires c.mu and sets the local port and endpoints and reports // whether they've changed. // @@ -747,15 +761,11 @@ func (c *Direct) newEndpoints(endpoints []tailcfg.Endpoint) (changed bool) { defer c.mu.Unlock() // Nothing new? - if sameEndpoints(c.endpoints, endpoints) { + if slices.Equal(c.endpoints, endpoints) { return false // unchanged } - var epStrs []string - for _, ep := range endpoints { - epStrs = append(epStrs, ep.Addr.String()) - } - c.logf("[v2] client.newEndpoints(%v)", epStrs) - c.endpoints = append(c.endpoints[:0], endpoints...) + c.logf("[v2] client.newEndpoints(%v)", endpoints) + c.endpoints = slices.Clone(endpoints) return true // changed } @@ -768,42 +778,60 @@ func (c *Direct) SetEndpoints(endpoints []tailcfg.Endpoint) (changed bool) { return c.newEndpoints(endpoints) } -// PollNetMap makes a /map request to download the network map, calling cb with -// each new netmap. -func (c *Direct) PollNetMap(ctx context.Context, cb func(*netmap.NetworkMap)) error { - return c.sendMapRequest(ctx, -1, false, cb) +// PollNetMap makes a /map request to download the network map, calling +// NetmapUpdater on each update from the control plane. +// +// It always returns a non-nil error describing the reason for the failure or +// why the request ended. +func (c *Direct) PollNetMap(ctx context.Context, nu NetmapUpdater) error { + return c.sendMapRequest(ctx, true, nu) } -// FetchNetMap fetches the netmap once. 
-func (c *Direct) FetchNetMap(ctx context.Context) (*netmap.NetworkMap, error) { - var ret *netmap.NetworkMap - err := c.sendMapRequest(ctx, 1, false, func(nm *netmap.NetworkMap) { - ret = nm - }) - if err == nil && ret == nil { +type rememberLastNetmapUpdater struct { + last *netmap.NetworkMap +} + +func (nu *rememberLastNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) { + nu.last = nm +} + +// FetchNetMapForTest fetches the netmap once. +func (c *Direct) FetchNetMapForTest(ctx context.Context) (*netmap.NetworkMap, error) { + var nu rememberLastNetmapUpdater + err := c.sendMapRequest(ctx, false, &nu) + if err == nil && nu.last == nil { return nil, errors.New("[unexpected] sendMapRequest success without callback") } - return ret, err + return nu.last, err } -// SendLiteMapUpdate makes a /map request to update the server of our latest state, -// but does not fetch anything. It returns an error if the server did not return a +// SendUpdate makes a /map request to update the server of our latest state, but +// does not fetch anything. It returns an error if the server did not return a // successful 200 OK response. -func (c *Direct) SendLiteMapUpdate(ctx context.Context) error { - return c.sendMapRequest(ctx, 1, false, nil) +func (c *Direct) SendUpdate(ctx context.Context) error { + return c.sendMapRequest(ctx, false, nil) } -// If we go more than pollTimeout without hearing from the server, +// If we go more than watchdogTimeout without hearing from the server, // end the long poll. We should be receiving a keep alive ping // every minute. -const pollTimeout = 120 * time.Second +const watchdogTimeout = 120 * time.Second + +// sendMapRequest makes a /map request to download the network map, calling cb +// with each new netmap. If isStreaming, it will poll forever and only returns +// if the context expires or the server returns an error/closes the connection +// and as such always returns a non-nil error. +// +// If cb is nil, OmitPeers will be set to true. 
+func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu NetmapUpdater) error { + if isStreaming && nu == nil { + panic("cb must be non-nil if isStreaming is true") + } -// cb nil means to omit peers. -func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool, cb func(*netmap.NetworkMap)) error { metricMapRequests.Add(1) metricMapRequestsActive.Add(1) defer metricMapRequestsActive.Add(-1) - if maxPolls == -1 { + if isStreaming { metricMapRequestsPoll.Add(1) } else { metricMapRequestsLite.Add(1) @@ -839,8 +867,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool return errors.New("hostinfo: BackendLogID missing") } - allowStream := maxPolls != 1 - c.logf("[v1] PollNetMap: stream=%v ep=%v", allowStream, epStrs) + c.logf("[v1] PollNetMap: stream=%v ep=%v", isStreaming, epStrs) vlogf := logger.Discard if DevKnob.DumpNetMaps() { @@ -851,28 +878,16 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool request := &tailcfg.MapRequest{ Version: tailcfg.CurrentCapabilityVersion, - KeepAlive: c.keepAlive, + KeepAlive: true, NodeKey: persist.PublicNodeKey(), DiscoKey: c.discoPubKey, Endpoints: epStrs, EndpointTypes: epTypes, - Stream: allowStream, + Stream: isStreaming, Hostinfo: hi, DebugFlags: c.debugFlags, - OmitPeers: cb == nil, + OmitPeers: nu == nil, TKAHead: c.tkaHead, - - // Previously we'd set ReadOnly to true if we didn't have any endpoints - // yet as we expected to learn them in a half second and restart the full - // streaming map poll, however as we are trying to reduce the number of - // times we restart the full streaming map poll we now just set ReadOnly - // false when we're doing a full streaming map poll. - // - // TODO(maisem/bradfitz): really ReadOnly should be set to true if for - // all streams and we should only do writes via lite map updates. - // However that requires an audit and a bunch of testing to make sure we - // don't break anything. 
- ReadOnly: readOnly && !allowStream, } var extraDebugFlags []string if hi != nil && c.netMon != nil && !c.skipIPForwardingCheck && @@ -890,9 +905,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool old := request.DebugFlags request.DebugFlags = append(old[:len(old):len(old)], extraDebugFlags...) } - if c.newDecompressor != nil { - request.Compress = "zstd" - } + request.Compress = "zstd" bodyData, err := encode(request, serverKey, serverNoiseKey, machinePrivKey) if err != nil { @@ -904,7 +917,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool defer cancel() machinePubKey := machinePrivKey.Public() - t0 := time.Now() + t0 := c.clock.Now() // Url and httpc are protocol specific. var url string @@ -942,45 +955,53 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool health.NoteMapRequestHeard(request) - if cb == nil { + if nu == nil { io.Copy(io.Discard, res.Body) return nil } - timeout := time.NewTimer(pollTimeout) - timeoutReset := make(chan struct{}) - pollDone := make(chan struct{}) - defer close(pollDone) - go func() { - for { - select { - case <-pollDone: - vlogf("netmap: ending timeout goroutine") - return - case <-timeout.C: - c.logf("map response long-poll timed out!") - cancel() - return - case <-timeoutReset: - if !timeout.Stop() { - select { - case <-timeout.C: - case <-pollDone: - vlogf("netmap: ending timeout goroutine") - return - } - } - vlogf("netmap: reset timeout timer") - timeout.Reset(pollTimeout) - } - } - }() + var mapResIdx int // 0 for first message, then 1+ for deltas - sess := newMapSession(persist.PrivateNodeKey()) + sess := newMapSession(persist.PrivateNodeKey(), nu, c.controlKnobs) + defer sess.Close() + sess.cancel = cancel sess.logf = c.logf sess.vlogf = vlogf + sess.altClock = c.clock sess.machinePubKey = machinePubKey - sess.keepSharerAndUserSplit = c.keepSharerAndUserSplit + sess.onDebug = c.handleDebugMessage + 
sess.onConciseNetMapSummary = func(summary string) { + // Occasionally print the netmap header. + // This is handy for debugging, and our logs processing + // pipeline depends on it. (TODO: Remove this dependency.) + now := c.clock.Now() + if now.Sub(c.lastPrintMap) < 5*time.Minute { + return + } + c.lastPrintMap = now + c.logf("[v1] new network map[%d]:\n%s", mapResIdx, summary) + } + sess.onSelfNodeChanged = func(nm *netmap.NetworkMap) { + c.mu.Lock() + defer c.mu.Unlock() + // If we are the ones who last updated persist, then we can update it + // again. Otherwise, we should not touch it. Also, it's only worth + // change it if the Node info changed. + if persist == c.persist { + newPersist := persist.AsStruct() + newPersist.NodeID = nm.SelfNode.StableID() + newPersist.UserProfile = nm.UserProfiles[nm.User()] + + c.persist = newPersist.View() + persist = c.persist + } + c.expiry = nm.Expiry + } + sess.StartWatchdog() + + // gotNonKeepAliveMessage is whether we've yet received a MapResponse message without + // KeepAlive set. + var gotNonKeepAliveMessage bool // If allowStream, then the server will use an HTTP long poll to // return incremental results. There is always one response right @@ -989,8 +1010,8 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool // the same format before just closing the connection. // We can use this same read loop either way. 
var msg []byte - for i := 0; i < maxPolls || maxPolls < 0; i++ { - vlogf("netmap: starting size read after %v (poll %v)", time.Since(t0).Round(time.Millisecond), i) + for ; mapResIdx == 0 || isStreaming; mapResIdx++ { + vlogf("netmap: starting size read after %v (poll %v)", time.Since(t0).Round(time.Millisecond), mapResIdx) var siz [4]byte if _, err := io.ReadFull(res.Body, siz[:]); err != nil { vlogf("netmap: size read error after %v: %v", time.Since(t0).Round(time.Millisecond), err) @@ -1013,7 +1034,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool metricMapResponseMessages.Add(1) - if allowStream { + if isStreaming { health.GotStreamedMapResponse() } @@ -1054,7 +1075,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool } select { - case timeoutReset <- struct{}{}: + case sess.watchdogReset <- struct{}{}: vlogf("netmap: sent timer reset") case <-ctx.Done(): c.logf("[v1] netmap: not resetting timer; context done: %v", ctx.Err()) @@ -1066,70 +1087,19 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool } metricMapResponseMap.Add(1) - if i > 0 { + if gotNonKeepAliveMessage { + // If we've already seen a non-keep-alive message, this is a delta update. metricMapResponseMapDelta.Add(1) + } else if resp.Node == nil { + // The very first non-keep-alive message should have Node populated. 
+ c.logf("initial MapResponse lacked Node") + return errors.New("initial MapResponse lacked node") } + gotNonKeepAliveMessage = true - hasDebug := resp.Debug != nil - // being conservative here, if Debug not present set to False - controlknobs.SetDisableUPnP(hasDebug && resp.Debug.DisableUPnP.EqualBool(true)) - if hasDebug { - if code := resp.Debug.Exit; code != nil { - c.logf("exiting process with status %v per controlplane", *code) - os.Exit(*code) - } - if resp.Debug.DisableLogTail { - logtail.Disable() - envknob.SetNoLogsNoSupport() - } - if resp.Debug.LogHeapPprof { - go logheap.LogHeap(resp.Debug.LogHeapURL) - } - if resp.Debug.GoroutineDumpURL != "" { - go dumpGoroutinesToURL(c.httpc, resp.Debug.GoroutineDumpURL) - } - if sleep := time.Duration(resp.Debug.SleepSeconds * float64(time.Second)); sleep > 0 { - if err := sleepAsRequested(ctx, c.logf, timeoutReset, sleep); err != nil { - return err - } - } - } - - nm := sess.netmapForResponse(&resp) - if nm.SelfNode == nil { - c.logf("MapResponse lacked node") - return errors.New("MapResponse lacked node") - } - - if d := nm.Debug; d != nil { - controlUseDERPRoute.Store(d.DERPRoute) - controlTrimWGConfig.Store(d.TrimWGConfig) - } - - if DevKnob.StripEndpoints() { - for _, p := range resp.Peers { - p.Endpoints = nil - } - } - if DevKnob.StripCaps() { - nm.SelfNode.Capabilities = nil - } - - // Occasionally print the netmap header. - // This is handy for debugging, and our logs processing - // pipeline depends on it. (TODO: Remove this dependency.) - // Code elsewhere prints netmap diffs every time they are received. 
- now := c.timeNow() - if now.Sub(c.lastPrintMap) >= 5*time.Minute { - c.lastPrintMap = now - c.logf("[v1] new network map[%d]:\n%s", i, nm.VeryConcise()) + if err := sess.HandleNonKeepAliveMapResponse(ctx, &resp); err != nil { + return err } - - c.mu.Lock() - c.expiry = &nm.Expiry - c.mu.Unlock() - - cb(nm) } if ctx.Err() != nil { return ctx.Err() @@ -1137,6 +1107,45 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool return nil } +func (c *Direct) handleDebugMessage(ctx context.Context, debug *tailcfg.Debug, watchdogReset chan<- struct{}) error { + if code := debug.Exit; code != nil { + c.logf("exiting process with status %v per controlplane", *code) + os.Exit(*code) + } + if debug.DisableLogTail { + logtail.Disable() + envknob.SetNoLogsNoSupport() + } + if sleep := time.Duration(debug.SleepSeconds * float64(time.Second)); sleep > 0 { + if err := sleepAsRequested(ctx, c.logf, watchdogReset, sleep, c.clock); err != nil { + return err + } + } + return nil +} + +// initDisplayNames mutates any tailcfg.Nodes in resp to populate their display names, +// calling InitDisplayNames on each. +// +// The magicDNSSuffix used is based on selfNode. +func initDisplayNames(selfNode tailcfg.NodeView, resp *tailcfg.MapResponse) { + if resp.Node == nil && len(resp.Peers) == 0 && len(resp.PeersChanged) == 0 { + // Fast path for a common case (delta updates). No need to compute + // magicDNSSuffix. + return + } + magicDNSSuffix := netmap.MagicDNSSuffixOfNodeName(selfNode.Name()) + if resp.Node != nil { + resp.Node.InitDisplayNames(magicDNSSuffix) + } + for _, n := range resp.Peers { + n.InitDisplayNames(magicDNSSuffix) + } + for _, n := range resp.PeersChanged { + n.InitDisplayNames(magicDNSSuffix) + } +} + // decode JSON decodes the res.Body into v. If serverNoiseKey is not specified, // it uses the serverKey and mkey to decode the message from the NaCl-crypto-box. 
func decode(res *http.Response, v any, serverKey, serverNoiseKey key.MachinePublic, mkey key.MachinePrivate) error { @@ -1180,19 +1189,14 @@ func (c *Direct) decodeMsg(msg []byte, v any, mkey key.MachinePrivate) error { } else { decrypted = msg } - var b []byte - if c.newDecompressor == nil { - b = decrypted - } else { - decoder, err := c.newDecompressor() - if err != nil { - return err - } - defer decoder.Close() - b, err = decoder.DecodeAll(decrypted, nil) - if err != nil { - return err - } + decoder, err := smallzstd.NewDecoder(nil) + if err != nil { + return err + } + defer decoder.Close() + b, err := decoder.DecodeAll(decrypted, nil) + if err != nil { + return err } if debugMap() { var buf bytes.Buffer @@ -1297,25 +1301,7 @@ func initDevKnob() devKnobs { } } -var clockNow = time.Now - -// opt.Bool configs from control. -var ( - controlUseDERPRoute syncs.AtomicValue[opt.Bool] - controlTrimWGConfig syncs.AtomicValue[opt.Bool] -) - -// DERPRouteFlag reports the last reported value from control for whether -// DERP route optimization (Issue 150) should be enabled. -func DERPRouteFlag() opt.Bool { - return controlUseDERPRoute.Load() -} - -// TrimWGConfig reports the last reported value from control for whether -// we should do lazy wireguard configuration. -func TrimWGConfig() opt.Bool { - return controlTrimWGConfig.Load() -} +var clock tstime.Clock = tstime.StdClock{} // ipForwardingBroken reports whether the system's IP forwarding is disabled // and will definitely not work for the routes provided. 
@@ -1401,9 +1387,9 @@ func answerHeadPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest) { if pr.Log { logf("answerHeadPing: sending HEAD ping to %v ...", pr.URL) } - t0 := time.Now() + t0 := clock.Now() _, err = c.Do(req) - d := time.Since(t0).Round(time.Millisecond) + d := clock.Since(t0).Round(time.Millisecond) if err != nil { logf("answerHeadPing error: %v to %v (after %v)", err, pr.URL, d) } else if pr.Log { @@ -1449,7 +1435,7 @@ func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr if pr.Log { logf("answerC2NPing: sending POST ping to %v ...", pr.URL) } - t0 := time.Now() + t0 := clock.Now() _, err = c.Do(req) d := time.Since(t0).Round(time.Millisecond) if err != nil { @@ -1459,7 +1445,11 @@ func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr } } -func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- struct{}, d time.Duration) error { +// sleepAsRequest implements the sleep for a tailcfg.Debug message requesting +// that the client sleep. The complication is that while we're sleeping (if for +// a long time), we need to periodically reset the watchdog timer before it +// expires. 
+func sleepAsRequested(ctx context.Context, logf logger.Logf, watchdogReset chan<- struct{}, d time.Duration, clock tstime.Clock) error { const maxSleep = 5 * time.Minute if d > maxSleep { logf("sleeping for %v, capped from server-requested %v ...", maxSleep, d) @@ -1468,20 +1458,20 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- logf("sleeping for server-requested %v ...", d) } - ticker := time.NewTicker(pollTimeout / 2) + ticker, tickerChannel := clock.NewTicker(watchdogTimeout / 2) defer ticker.Stop() - timer := time.NewTimer(d) + timer, timerChannel := clock.NewTimer(d) defer timer.Stop() for { select { case <-ctx.Done(): return ctx.Err() - case <-timer.C: + case <-timerChannel: return nil - case <-ticker.C: + case <-tickerChannel: select { - case timeoutReset <- struct{}{}: - case <-timer.C: + case watchdogReset <- struct{}{}: + case <-timerChannel: return nil case <-ctx.Done(): return ctx.Err() @@ -1511,7 +1501,7 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) { if err != nil { return nil, err } - c.logf("creating new noise client") + c.logf("[v1] creating new noise client") nc, err := NewNoiseClient(NoiseOpts{ PrivKey: k, ServerPubKey: serverNoiseKey, @@ -1658,12 +1648,12 @@ func doPingerPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest, pin logf("invalid ping request: missing url, ip or pinger") return } - start := time.Now() + start := clock.Now() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - res, err := pinger.Ping(ctx, pr.IP, pingType) + res, err := pinger.Ping(ctx, pr.IP, pingType, 0) if err != nil { d := time.Since(start).Round(time.Millisecond) logf("doPingerPing: ping error of type %q to %v after %v: %v", pingType, pr.IP, d, err) @@ -1696,7 +1686,7 @@ func postPingResult(start time.Time, logf logger.Logf, c *http.Client, pr *tailc if pr.Log { logf("postPingResult: sending ping results to %v ...", pr.URL) } - t0 := time.Now() + t0 := clock.Now() _, 
err = c.Do(req) d := time.Since(t0).Round(time.Millisecond) if err != nil { diff --git a/vendor/tailscale.com/control/controlclient/map.go b/vendor/tailscale.com/control/controlclient/map.go index e1161c34a8..8623208b79 100644 --- a/vendor/tailscale.com/control/controlclient/map.go +++ b/vendor/tailscale.com/control/controlclient/map.go @@ -4,18 +4,29 @@ package controlclient import ( + "context" + "encoding/json" "fmt" - "log" + "net" "net/netip" + "reflect" + "slices" "sort" + "strconv" + "sync" + "time" + "tailscale.com/control/controlknobs" "tailscale.com/envknob" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" - "tailscale.com/types/opt" + "tailscale.com/types/ptr" "tailscale.com/types/views" + "tailscale.com/util/clientmetric" + "tailscale.com/util/cmpx" "tailscale.com/wgengine/filter" ) @@ -29,14 +40,42 @@ import ( // one MapRequest). type mapSession struct { // Immutable fields. - privateNodeKey key.NodePrivate - logf logger.Logf - vlogf logger.Logf - machinePubKey key.MachinePublic - keepSharerAndUserSplit bool // see Options.KeepSharerAndUserSplit + netmapUpdater NetmapUpdater // called on changes (in addition to the optional hooks below) + controlKnobs *controlknobs.Knobs // or nil + privateNodeKey key.NodePrivate + publicNodeKey key.NodePublic + logf logger.Logf + vlogf logger.Logf + machinePubKey key.MachinePublic + altClock tstime.Clock // if nil, regular time is used + cancel context.CancelFunc // always non-nil, shuts down caller's base long poll context + watchdogReset chan struct{} // send to request that the long poll activity watchdog timeout be reset + + // sessionAliveCtx is a Background-based context that's alive for the + // duration of the mapSession that we own the lifetime of. It's closed by + // sessionAliveCtxClose. 
+ sessionAliveCtx context.Context + sessionAliveCtxClose context.CancelFunc // closes sessionAliveCtx + + // Optional hooks, set once before use. + + // onDebug specifies what to do with a *tailcfg.Debug message. + // If the watchdogReset chan is nil, it's not used. Otherwise it can be sent to + // to request that the long poll activity watchdog timeout be reset. + onDebug func(_ context.Context, _ *tailcfg.Debug, watchdogReset chan<- struct{}) error + + // onConciseNetMapSummary, if non-nil, is called with the Netmap.VeryConcise summary + // whenever a map response is received. + onConciseNetMapSummary func(string) + + // onSelfNodeChanged is called before the NetmapUpdater if the self node was + // changed. + onSelfNodeChanged func(*netmap.NetworkMap) // Fields storing state over the course of multiple MapResponses. - lastNode *tailcfg.Node + lastNode tailcfg.NodeView + peers map[tailcfg.NodeID]*tailcfg.NodeView // pointer to view (oddly). same pointers as sortedPeers. + sortedPeers []*tailcfg.NodeView // same pointers as peers, but sorted by Node.ID lastDNSConfig *tailcfg.DNSConfig lastDERPMap *tailcfg.DERPMap lastUserProfile map[tailcfg.UserID]tailcfg.UserProfile @@ -44,55 +83,204 @@ type mapSession struct { lastParsedPacketFilter []filter.Match lastSSHPolicy *tailcfg.SSHPolicy collectServices bool - previousPeers []*tailcfg.Node // for delta-purposes lastDomain string lastDomainAuditLogID string lastHealth []string lastPopBrowserURL string stickyDebug tailcfg.Debug // accumulated opt.Bool values lastTKAInfo *tailcfg.TKAInfo - - // netMapBuilding is non-nil during a netmapForResponse call, - // containing the value to be returned, once fully populated. - netMapBuilding *netmap.NetworkMap + lastNetmapSummary string // from NetworkMap.VeryConcise } -func newMapSession(privateNodeKey key.NodePrivate) *mapSession { +// newMapSession returns a mostly unconfigured new mapSession. +// +// Modify its optional fields on the returned value before use. 
+// +// It must have its Close method called to release resources. +func newMapSession(privateNodeKey key.NodePrivate, nu NetmapUpdater, controlKnobs *controlknobs.Knobs) *mapSession { ms := &mapSession{ + netmapUpdater: nu, + controlKnobs: controlKnobs, privateNodeKey: privateNodeKey, - logf: logger.Discard, - vlogf: logger.Discard, + publicNodeKey: privateNodeKey.Public(), lastDNSConfig: new(tailcfg.DNSConfig), lastUserProfile: map[tailcfg.UserID]tailcfg.UserProfile{}, - } + watchdogReset: make(chan struct{}), + + // Non-nil no-op defaults, to be optionally overridden by the caller. + logf: logger.Discard, + vlogf: logger.Discard, + cancel: func() {}, + onDebug: func(context.Context, *tailcfg.Debug, chan<- struct{}) error { return nil }, + onConciseNetMapSummary: func(string) {}, + onSelfNodeChanged: func(*netmap.NetworkMap) {}, + } + ms.sessionAliveCtx, ms.sessionAliveCtxClose = context.WithCancel(context.Background()) return ms } -func (ms *mapSession) addUserProfile(userID tailcfg.UserID) { - nm := ms.netMapBuilding - if _, dup := nm.UserProfiles[userID]; dup { - // Already populated it from a previous peer. - return +func (ms *mapSession) clock() tstime.Clock { + return cmpx.Or[tstime.Clock](ms.altClock, tstime.StdClock{}) +} + +// StartWatchdog starts the session's watchdog timer. +// If there's no activity in too long, it tears down the connection. +// Call Close to release these resources. 
+func (ms *mapSession) StartWatchdog() { + timer, timedOutChan := ms.clock().NewTimer(watchdogTimeout) + go func() { + defer timer.Stop() + for { + select { + case <-ms.sessionAliveCtx.Done(): + ms.vlogf("netmap: ending timeout goroutine") + return + case <-timedOutChan: + ms.logf("map response long-poll timed out!") + ms.cancel() + return + case <-ms.watchdogReset: + if !timer.Stop() { + select { + case <-timedOutChan: + case <-ms.sessionAliveCtx.Done(): + ms.vlogf("netmap: ending timeout goroutine") + return + } + } + ms.vlogf("netmap: reset timeout timer") + timer.Reset(watchdogTimeout) + } + } + }() +} + +func (ms *mapSession) Close() { + ms.sessionAliveCtxClose() +} + +// HandleNonKeepAliveMapResponse handles a non-KeepAlive MapResponse (full or +// incremental). +// +// All fields that are valid on a KeepAlive MapResponse have already been +// handled. +// +// TODO(bradfitz): make this handle all fields later. For now (2023-08-20) this +// is [re]factoring progress enough. +func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse) error { + if debug := resp.Debug; debug != nil { + if err := ms.onDebug(ctx, debug, ms.watchdogReset); err != nil { + return err + } } - if up, ok := ms.lastUserProfile[userID]; ok { - nm.UserProfiles[userID] = up + + if DevKnob.StripEndpoints() { + for _, p := range resp.Peers { + p.Endpoints = nil + } + for _, p := range resp.PeersChanged { + p.Endpoints = nil + } } + + // For responses that mutate the self node, check for updated nodeAttrs. + if resp.Node != nil { + if DevKnob.StripCaps() { + resp.Node.Capabilities = nil + resp.Node.CapMap = nil + } + ms.controlKnobs.UpdateFromNodeAttributes(resp.Node.Capabilities, resp.Node.CapMap) + } + + // Call Node.InitDisplayNames on any changed nodes. 
+ initDisplayNames(cmpx.Or(resp.Node.View(), ms.lastNode), resp) + + ms.patchifyPeersChanged(resp) + + ms.updateStateFromResponse(resp) + + if ms.tryHandleIncrementally(resp) { + ms.onConciseNetMapSummary(ms.lastNetmapSummary) // every 5s log + return nil + } + + // We have to rebuild the whole netmap (lots of garbage & work downstream of + // our UpdateFullNetmap call). This is the part we tried to avoid but + // some field mutations (especially rare ones) aren't yet handled. + + nm := ms.netmap() + ms.lastNetmapSummary = nm.VeryConcise() + ms.onConciseNetMapSummary(ms.lastNetmapSummary) + + // If the self node changed, we might need to update persist. + if resp.Node != nil { + ms.onSelfNodeChanged(nm) + } + + ms.netmapUpdater.UpdateFullNetmap(nm) + return nil } -// netmapForResponse returns a fully populated NetworkMap from a full -// or incremental MapResponse within the session, filling in omitted -// information from prior MapResponse values. -func (ms *mapSession) netmapForResponse(resp *tailcfg.MapResponse) *netmap.NetworkMap { - undeltaPeers(resp, ms.previousPeers) +func (ms *mapSession) tryHandleIncrementally(res *tailcfg.MapResponse) bool { + if ms.controlKnobs != nil && ms.controlKnobs.DisableDeltaUpdates.Load() { + return false + } + nud, ok := ms.netmapUpdater.(NetmapDeltaUpdater) + if !ok { + return false + } + mutations, ok := netmap.MutationsFromMapResponse(res, time.Now()) + if ok && len(mutations) > 0 { + return nud.UpdateNetmapDelta(mutations) + } + return ok +} + +// updateStats are some stats from updateStateFromResponse, primarily for +// testing. It's meant to be cheap enough to always compute, though. It doesn't +// allocate. +type updateStats struct { + allNew bool + added int + removed int + changed int +} + +// updateStateFromResponse updates ms from res. It takes ownership of res. 
+func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) { + ms.updatePeersStateFromResponse(resp) + + if resp.Node != nil { + ms.lastNode = resp.Node.View() + } - ms.previousPeers = cloneNodes(resp.Peers) // defensive/lazy clone, since this escapes to who knows where for _, up := range resp.UserProfiles { ms.lastUserProfile[up.ID] = up } + // TODO(bradfitz): clean up old user profiles? maybe not worth it. - if resp.DERPMap != nil { + if dm := resp.DERPMap; dm != nil { ms.vlogf("netmap: new map contains DERP map") - ms.lastDERPMap = resp.DERPMap + + // Zero-valued fields in a DERPMap mean that we're not changing + // anything and are using the previous value(s). + if ldm := ms.lastDERPMap; ldm != nil { + if dm.Regions == nil { + dm.Regions = ldm.Regions + dm.OmitDefaultRegions = ldm.OmitDefaultRegions + } + if dm.HomeParams == nil { + dm.HomeParams = ldm.HomeParams + } else if oldhh := ldm.HomeParams; oldhh != nil { + // Propagate sub-fields of HomeParams + hh := dm.HomeParams + if hh.RegionScore == nil { + hh.RegionScore = oldhh.RegionScore + } + } + } + + ms.lastDERPMap = dm } if pf := resp.PacketFilter; pf != nil { @@ -125,234 +313,474 @@ func (ms *mapSession) netmapForResponse(resp *tailcfg.MapResponse) *netmap.Netwo if resp.TKAInfo != nil { ms.lastTKAInfo = resp.TKAInfo } +} - debug := resp.Debug - if debug != nil { - if debug.RandomizeClientPort { - debug.SetRandomizeClientPort.Set(true) - } - if debug.ForceBackgroundSTUN { - debug.SetForceBackgroundSTUN.Set(true) +var ( + patchDERPRegion = clientmetric.NewCounter("controlclient_patch_derp") + patchEndpoints = clientmetric.NewCounter("controlclient_patch_endpoints") + patchCap = clientmetric.NewCounter("controlclient_patch_capver") + patchKey = clientmetric.NewCounter("controlclient_patch_key") + patchDiscoKey = clientmetric.NewCounter("controlclient_patch_discokey") + patchOnline = clientmetric.NewCounter("controlclient_patch_online") + patchLastSeen = 
clientmetric.NewCounter("controlclient_patch_lastseen") + patchKeyExpiry = clientmetric.NewCounter("controlclient_patch_keyexpiry") + patchCapabilities = clientmetric.NewCounter("controlclient_patch_capabilities") + patchCapMap = clientmetric.NewCounter("controlclient_patch_capmap") + patchKeySignature = clientmetric.NewCounter("controlclient_patch_keysig") + + patchifiedPeer = clientmetric.NewCounter("controlclient_patchified_peer") + patchifiedPeerEqual = clientmetric.NewCounter("controlclient_patchified_peer_equal") +) + +// updatePeersStateFromResponseres updates ms.peers and ms.sortedPeers from res. It takes ownership of res. +func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (stats updateStats) { + defer func() { + if stats.removed > 0 || stats.added > 0 { + ms.rebuildSorted() } - copyDebugOptBools(&ms.stickyDebug, debug) - } else if ms.stickyDebug != (tailcfg.Debug{}) { - debug = new(tailcfg.Debug) + }() + + if ms.peers == nil { + ms.peers = make(map[tailcfg.NodeID]*tailcfg.NodeView) } - if debug != nil { - copyDebugOptBools(debug, &ms.stickyDebug) - if !debug.ForceBackgroundSTUN { - debug.ForceBackgroundSTUN, _ = ms.stickyDebug.SetForceBackgroundSTUN.Get() + + if len(resp.Peers) > 0 { + // Not delta encoded. + stats.allNew = true + keep := make(map[tailcfg.NodeID]bool, len(resp.Peers)) + for _, n := range resp.Peers { + keep[n.ID] = true + if vp, ok := ms.peers[n.ID]; ok { + stats.changed++ + *vp = n.View() + } else { + stats.added++ + ms.peers[n.ID] = ptr.To(n.View()) + } } - if !debug.RandomizeClientPort { - debug.RandomizeClientPort, _ = ms.stickyDebug.SetRandomizeClientPort.Get() + for id := range ms.peers { + if !keep[id] { + stats.removed++ + delete(ms.peers, id) + } } + // Peers precludes all other delta operations so just return. 
+ return } - nm := &netmap.NetworkMap{ - NodeKey: ms.privateNodeKey.Public(), - PrivateKey: ms.privateNodeKey, - MachineKey: ms.machinePubKey, - Peers: resp.Peers, - UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile), - Domain: ms.lastDomain, - DomainAuditLogID: ms.lastDomainAuditLogID, - DNS: *ms.lastDNSConfig, - PacketFilter: ms.lastParsedPacketFilter, - PacketFilterRules: ms.lastPacketFilterRules, - SSHPolicy: ms.lastSSHPolicy, - CollectServices: ms.collectServices, - DERPMap: ms.lastDERPMap, - Debug: debug, - ControlHealth: ms.lastHealth, - TKAEnabled: ms.lastTKAInfo != nil && !ms.lastTKAInfo.Disabled, - } - ms.netMapBuilding = nm - - if ms.lastTKAInfo != nil && ms.lastTKAInfo.Head != "" { - if err := nm.TKAHead.UnmarshalText([]byte(ms.lastTKAInfo.Head)); err != nil { - ms.logf("error unmarshalling TKAHead: %v", err) - nm.TKAEnabled = false + for _, id := range resp.PeersRemoved { + if _, ok := ms.peers[id]; ok { + delete(ms.peers, id) + stats.removed++ } } - if resp.Node != nil { - ms.lastNode = resp.Node - } - if node := ms.lastNode.Clone(); node != nil { - nm.SelfNode = node - nm.Expiry = node.KeyExpiry - nm.Name = node.Name - nm.Addresses = filterSelfAddresses(node.Addresses) - nm.User = node.User - if node.Hostinfo.Valid() { - nm.Hostinfo = *node.Hostinfo.AsStruct() - } - if node.MachineAuthorized { - nm.MachineStatus = tailcfg.MachineAuthorized + for _, n := range resp.PeersChanged { + if vp, ok := ms.peers[n.ID]; ok { + stats.changed++ + *vp = n.View() } else { - nm.MachineStatus = tailcfg.MachineUnauthorized + stats.added++ + ms.peers[n.ID] = ptr.To(n.View()) } } - ms.addUserProfile(nm.User) - magicDNSSuffix := nm.MagicDNSSuffix() - if nm.SelfNode != nil { - nm.SelfNode.InitDisplayNames(magicDNSSuffix) - } - for _, peer := range resp.Peers { - peer.InitDisplayNames(magicDNSSuffix) - if !peer.Sharer.IsZero() { - if ms.keepSharerAndUserSplit { - ms.addUserProfile(peer.Sharer) + for nodeID, seen := range resp.PeerSeenChange { + if vp, ok := 
ms.peers[nodeID]; ok { + mut := vp.AsStruct() + if seen { + mut.LastSeen = ptr.To(clock.Now()) } else { - peer.User = peer.Sharer + mut.LastSeen = nil } + *vp = mut.View() + stats.changed++ } - ms.addUserProfile(peer.User) } - if DevKnob.ForceProxyDNS() { - nm.DNS.Proxied = true + + for nodeID, online := range resp.OnlineChange { + if vp, ok := ms.peers[nodeID]; ok { + mut := vp.AsStruct() + mut.Online = ptr.To(online) + *vp = mut.View() + stats.changed++ + } } - ms.netMapBuilding = nil - return nm -} -// undeltaPeers updates mapRes.Peers to be complete based on the -// provided previous peer list and the PeersRemoved and PeersChanged -// fields in mapRes, as well as the PeerSeenChange and OnlineChange -// maps. -// -// It then also nils out the delta fields. -func undeltaPeers(mapRes *tailcfg.MapResponse, prev []*tailcfg.Node) { - if len(mapRes.Peers) > 0 { - // Not delta encoded. - if !nodesSorted(mapRes.Peers) { - log.Printf("netmap: undeltaPeers: MapResponse.Peers not sorted; sorting") - sortNodes(mapRes.Peers) + for _, pc := range resp.PeersChangedPatch { + vp, ok := ms.peers[pc.NodeID] + if !ok { + continue } - return + stats.changed++ + mut := vp.AsStruct() + if pc.DERPRegion != 0 { + mut.DERP = fmt.Sprintf("%s:%v", tailcfg.DerpMagicIP, pc.DERPRegion) + patchDERPRegion.Add(1) + } + if pc.Cap != 0 { + mut.Cap = pc.Cap + patchCap.Add(1) + } + if pc.Endpoints != nil { + mut.Endpoints = pc.Endpoints + patchEndpoints.Add(1) + } + if pc.Key != nil { + mut.Key = *pc.Key + patchKey.Add(1) + } + if pc.DiscoKey != nil { + mut.DiscoKey = *pc.DiscoKey + patchDiscoKey.Add(1) + } + if v := pc.Online; v != nil { + mut.Online = ptr.To(*v) + patchOnline.Add(1) + } + if v := pc.LastSeen; v != nil { + mut.LastSeen = ptr.To(*v) + patchLastSeen.Add(1) + } + if v := pc.KeyExpiry; v != nil { + mut.KeyExpiry = *v + patchKeyExpiry.Add(1) + } + if v := pc.Capabilities; v != nil { + mut.Capabilities = *v + patchCapabilities.Add(1) + } + if v := pc.KeySignature; v != nil { + 
mut.KeySignature = v + patchKeySignature.Add(1) + } + if v := pc.CapMap; v != nil { + mut.CapMap = v + patchCapMap.Add(1) + } + *vp = mut.View() } - var removed map[tailcfg.NodeID]bool - if pr := mapRes.PeersRemoved; len(pr) > 0 { - removed = make(map[tailcfg.NodeID]bool, len(pr)) - for _, id := range pr { - removed[id] = true + return +} + +// rebuildSorted rebuilds ms.sortedPeers from ms.peers. It should be called +// after any additions or removals from peers. +func (ms *mapSession) rebuildSorted() { + if ms.sortedPeers == nil { + ms.sortedPeers = make([]*tailcfg.NodeView, 0, len(ms.peers)) + } else { + if len(ms.sortedPeers) > len(ms.peers) { + clear(ms.sortedPeers[len(ms.peers):]) } + ms.sortedPeers = ms.sortedPeers[:0] + } + for _, p := range ms.peers { + ms.sortedPeers = append(ms.sortedPeers, p) } - changed := mapRes.PeersChanged + sort.Slice(ms.sortedPeers, func(i, j int) bool { + return ms.sortedPeers[i].ID() < ms.sortedPeers[j].ID() + }) +} - if !nodesSorted(changed) { - log.Printf("netmap: undeltaPeers: MapResponse.PeersChanged not sorted; sorting") - sortNodes(changed) +func (ms *mapSession) addUserProfile(nm *netmap.NetworkMap, userID tailcfg.UserID) { + if userID == 0 { + return } - if !nodesSorted(prev) { - // Internal error (unrelated to the network) if we get here. - log.Printf("netmap: undeltaPeers: [unexpected] prev not sorted; sorting") - sortNodes(prev) + if _, dup := nm.UserProfiles[userID]; dup { + // Already populated it from a previous peer. 
+ return } + if up, ok := ms.lastUserProfile[userID]; ok { + nm.UserProfiles[userID] = up + } +} - newFull := prev - if len(removed) > 0 || len(changed) > 0 { - newFull = make([]*tailcfg.Node, 0, len(prev)-len(removed)) - for len(prev) > 0 && len(changed) > 0 { - pID := prev[0].ID - cID := changed[0].ID - if removed[pID] { - prev = prev[1:] - continue - } - switch { - case pID < cID: - newFull = append(newFull, prev[0]) - prev = prev[1:] - case pID == cID: - newFull = append(newFull, changed[0]) - prev, changed = prev[1:], changed[1:] - case cID < pID: - newFull = append(newFull, changed[0]) - changed = changed[1:] +var debugPatchifyPeer = envknob.RegisterBool("TS_DEBUG_PATCHIFY_PEER") + +// patchifyPeersChanged mutates resp to promote PeersChanged entries to PeersChangedPatch +// when possible. +func (ms *mapSession) patchifyPeersChanged(resp *tailcfg.MapResponse) { + filtered := resp.PeersChanged[:0] + for _, n := range resp.PeersChanged { + if p, ok := ms.patchifyPeer(n); ok { + patchifiedPeer.Add(1) + if debugPatchifyPeer() { + patchj, _ := json.Marshal(p) + ms.logf("debug: patchifyPeer[ID=%v]: %s", n.ID, patchj) } - } - newFull = append(newFull, changed...) - for _, n := range prev { - if !removed[n.ID] { - newFull = append(newFull, n) + if p != nil { + resp.PeersChangedPatch = append(resp.PeersChangedPatch, p) + } else { + patchifiedPeerEqual.Add(1) } + } else { + filtered = append(filtered, n) } - sortNodes(newFull) } + resp.PeersChanged = filtered + if len(resp.PeersChanged) == 0 { + resp.PeersChanged = nil + } +} - if len(mapRes.PeerSeenChange) != 0 || len(mapRes.OnlineChange) != 0 || len(mapRes.PeersChangedPatch) != 0 { - peerByID := make(map[tailcfg.NodeID]*tailcfg.Node, len(newFull)) - for _, n := range newFull { - peerByID[n.ID] = n +var nodeFields = sync.OnceValue(getNodeFields) + +// getNodeFields returns the fails of tailcfg.Node. 
+func getNodeFields() []string { + rt := reflect.TypeOf((*tailcfg.Node)(nil)).Elem() + ret := make([]string, rt.NumField()) + for i := 0; i < rt.NumField(); i++ { + ret[i] = rt.Field(i).Name + } + return ret +} + +// patchifyPeer returns a *tailcfg.PeerChange of the session's existing copy of +// the n.ID Node to n. +// +// It returns ok=false if a patch can't be made, (V, ok) on a delta, or (nil, +// true) if all the fields were identical (a zero change). +func (ms *mapSession) patchifyPeer(n *tailcfg.Node) (_ *tailcfg.PeerChange, ok bool) { + was, ok := ms.peers[n.ID] + if !ok { + return nil, false + } + return peerChangeDiff(*was, n) +} + +// peerChangeDiff returns the difference from 'was' to 'n', if possible. +// +// It returns (nil, true) if the fields were identical. +func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChange, ok bool) { + var ret *tailcfg.PeerChange + pc := func() *tailcfg.PeerChange { + if ret == nil { + ret = new(tailcfg.PeerChange) } - now := clockNow() - for nodeID, seen := range mapRes.PeerSeenChange { - if n, ok := peerByID[nodeID]; ok { - if seen { - n.LastSeen = &now - } else { - n.LastSeen = nil - } + return ret + } + for _, field := range nodeFields() { + switch field { + default: + // The whole point of using reflect in this function is to panic + // here in tests if we forget to handle a new field. + panic("unhandled field: " + field) + case "computedHostIfDifferent", "ComputedName", "ComputedNameWithHost": + // Caller's responsibility to have populated these. + continue + case "DataPlaneAuditLogID": + // Not sent for peers. 
+ case "ID": + if was.ID() != n.ID { + return nil, false } - } - for nodeID, online := range mapRes.OnlineChange { - if n, ok := peerByID[nodeID]; ok { - online := online - n.Online = &online + case "StableID": + if was.StableID() != n.StableID { + return nil, false } - } - for _, ec := range mapRes.PeersChangedPatch { - if n, ok := peerByID[ec.NodeID]; ok { - if ec.DERPRegion != 0 { - n.DERP = fmt.Sprintf("%s:%v", tailcfg.DerpMagicIP, ec.DERPRegion) - } - if ec.Cap != 0 { - n.Cap = ec.Cap - } - if ec.Endpoints != nil { - n.Endpoints = ec.Endpoints - } - if ec.Key != nil { - n.Key = *ec.Key - } - if ec.DiscoKey != nil { - n.DiscoKey = *ec.DiscoKey - } - if v := ec.Online; v != nil { - n.Online = ptrCopy(v) - } - if v := ec.LastSeen; v != nil { - n.LastSeen = ptrCopy(v) - } - if v := ec.KeyExpiry; v != nil { - n.KeyExpiry = *v + case "Name": + if was.Name() != n.Name { + return nil, false + } + case "User": + if was.User() != n.User { + return nil, false + } + case "Sharer": + if was.Sharer() != n.Sharer { + return nil, false + } + case "Key": + if was.Key() != n.Key { + pc().Key = ptr.To(n.Key) + } + case "KeyExpiry": + if !was.KeyExpiry().Equal(n.KeyExpiry) { + pc().KeyExpiry = ptr.To(n.KeyExpiry) + } + case "KeySignature": + if !was.KeySignature().Equal(n.KeySignature) { + pc().KeySignature = slices.Clone(n.KeySignature) + } + case "Machine": + if was.Machine() != n.Machine { + return nil, false + } + case "DiscoKey": + if was.DiscoKey() != n.DiscoKey { + pc().DiscoKey = ptr.To(n.DiscoKey) + } + case "Addresses": + if !views.SliceEqual(was.Addresses(), views.SliceOf(n.Addresses)) { + return nil, false + } + case "AllowedIPs": + if !views.SliceEqual(was.AllowedIPs(), views.SliceOf(n.AllowedIPs)) { + return nil, false + } + case "Endpoints": + if !views.SliceEqual(was.Endpoints(), views.SliceOf(n.Endpoints)) { + pc().Endpoints = slices.Clone(n.Endpoints) + } + case "DERP": + if was.DERP() != n.DERP { + ip, portStr, err := net.SplitHostPort(n.DERP) + if err != nil 
|| ip != "127.3.3.40" { + return nil, false } - if v := ec.Capabilities; v != nil { - n.Capabilities = *v + port, err := strconv.Atoi(portStr) + if err != nil || port < 1 || port > 65535 { + return nil, false } - if v := ec.KeySignature; v != nil { - n.KeySignature = v + pc().DERPRegion = port + } + case "Hostinfo": + if !was.Hostinfo().Valid() && !n.Hostinfo.Valid() { + continue + } + if !was.Hostinfo().Valid() || !n.Hostinfo.Valid() { + return nil, false + } + if !was.Hostinfo().Equal(n.Hostinfo) { + return nil, false + } + case "Created": + if !was.Created().Equal(n.Created) { + return nil, false + } + case "Cap": + if was.Cap() != n.Cap { + pc().Cap = n.Cap + } + case "CapMap": + if n.CapMap != nil { + pc().CapMap = n.CapMap + } + case "Tags": + if !views.SliceEqual(was.Tags(), views.SliceOf(n.Tags)) { + return nil, false + } + case "PrimaryRoutes": + if !views.SliceEqual(was.PrimaryRoutes(), views.SliceOf(n.PrimaryRoutes)) { + return nil, false + } + case "Online": + wasOnline := was.Online() + if n.Online != nil && wasOnline != nil && *n.Online != *wasOnline { + pc().Online = ptr.To(*n.Online) + } + case "LastSeen": + wasSeen := was.LastSeen() + if n.LastSeen != nil && wasSeen != nil && !wasSeen.Equal(*n.LastSeen) { + pc().LastSeen = ptr.To(*n.LastSeen) + } + case "MachineAuthorized": + if was.MachineAuthorized() != n.MachineAuthorized { + return nil, false + } + case "Capabilities": + if !views.SliceEqual(was.Capabilities(), views.SliceOf(n.Capabilities)) { + pc().Capabilities = ptr.To(n.Capabilities) + } + case "UnsignedPeerAPIOnly": + if was.UnsignedPeerAPIOnly() != n.UnsignedPeerAPIOnly { + return nil, false + } + case "IsWireGuardOnly": + if was.IsWireGuardOnly() != n.IsWireGuardOnly { + return nil, false + } + case "Expired": + if was.Expired() != n.Expired { + return nil, false + } + case "SelfNodeV4MasqAddrForThisPeer": + va, vb := was.SelfNodeV4MasqAddrForThisPeer(), n.SelfNodeV4MasqAddrForThisPeer + if va == nil && vb == nil { + continue + } + if va 
== nil || vb == nil || *va != *vb { + return nil, false + } + case "SelfNodeV6MasqAddrForThisPeer": + va, vb := was.SelfNodeV6MasqAddrForThisPeer(), n.SelfNodeV6MasqAddrForThisPeer + if va == nil && vb == nil { + continue + } + if va == nil || vb == nil || *va != *vb { + return nil, false + } + case "ExitNodeDNSResolvers": + va, vb := was.ExitNodeDNSResolvers(), views.SliceOfViews(n.ExitNodeDNSResolvers) + + if va.Len() != vb.Len() { + return nil, false + } + + for i := range va.LenIter() { + if !va.At(i).Equal(vb.At(i)) { + return nil, false } } + } } - - mapRes.Peers = newFull - mapRes.PeersChanged = nil - mapRes.PeersRemoved = nil + if ret != nil { + ret.NodeID = n.ID + } + return ret, true } -// ptrCopy returns a pointer to a newly allocated shallow copy of *v. -func ptrCopy[T any](v *T) *T { - if v == nil { - return nil +// netmap returns a fully populated NetworkMap from the last state seen from +// a call to updateStateFromResponse, filling in omitted +// information from prior MapResponse values. 
+func (ms *mapSession) netmap() *netmap.NetworkMap { + peerViews := make([]tailcfg.NodeView, len(ms.sortedPeers)) + for i, vp := range ms.sortedPeers { + peerViews[i] = *vp } - ret := new(T) - *ret = *v - return ret + + nm := &netmap.NetworkMap{ + NodeKey: ms.publicNodeKey, + PrivateKey: ms.privateNodeKey, + MachineKey: ms.machinePubKey, + Peers: peerViews, + UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile), + Domain: ms.lastDomain, + DomainAuditLogID: ms.lastDomainAuditLogID, + DNS: *ms.lastDNSConfig, + PacketFilter: ms.lastParsedPacketFilter, + PacketFilterRules: ms.lastPacketFilterRules, + SSHPolicy: ms.lastSSHPolicy, + CollectServices: ms.collectServices, + DERPMap: ms.lastDERPMap, + ControlHealth: ms.lastHealth, + TKAEnabled: ms.lastTKAInfo != nil && !ms.lastTKAInfo.Disabled, + } + + if ms.lastTKAInfo != nil && ms.lastTKAInfo.Head != "" { + if err := nm.TKAHead.UnmarshalText([]byte(ms.lastTKAInfo.Head)); err != nil { + ms.logf("error unmarshalling TKAHead: %v", err) + nm.TKAEnabled = false + } + } + + if node := ms.lastNode; node.Valid() { + nm.SelfNode = node + nm.Expiry = node.KeyExpiry() + nm.Name = node.Name() + } + + ms.addUserProfile(nm, nm.User()) + for _, peer := range peerViews { + ms.addUserProfile(nm, peer.Sharer()) + ms.addUserProfile(nm, peer.User()) + } + if DevKnob.ForceProxyDNS() { + nm.DNS.Proxied = true + } + return nm } func nodesSorted(v []*tailcfg.Node) bool { @@ -394,18 +822,3 @@ func filterSelfAddresses(in []netip.Prefix) (ret []netip.Prefix) { return ret } } - -func copyDebugOptBools(dst, src *tailcfg.Debug) { - copy := func(v *opt.Bool, s opt.Bool) { - if s != "" { - *v = s - } - } - copy(&dst.DERPRoute, src.DERPRoute) - copy(&dst.DisableSubnetsIfPAC, src.DisableSubnetsIfPAC) - copy(&dst.DisableUPnP, src.DisableUPnP) - copy(&dst.OneCGNATRoute, src.OneCGNATRoute) - copy(&dst.SetForceBackgroundSTUN, src.SetForceBackgroundSTUN) - copy(&dst.SetRandomizeClientPort, src.SetRandomizeClientPort) - copy(&dst.TrimWGConfig, 
src.TrimWGConfig) -} diff --git a/vendor/tailscale.com/control/controlclient/noise.go b/vendor/tailscale.com/control/controlclient/noise.go index cad81b82ce..a9dd201800 100644 --- a/vendor/tailscale.com/control/controlclient/noise.go +++ b/vendor/tailscale.com/control/controlclient/noise.go @@ -23,6 +23,7 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/mak" @@ -287,6 +288,25 @@ func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.Round return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection") } +// contextErr is an error that wraps another error and is used to indicate that +// the error was because a context expired. +type contextErr struct { + err error +} + +func (e contextErr) Error() string { + return e.err.Error() +} + +func (e contextErr) Unwrap() error { + return e.err +} + +// getConn returns a noiseConn that can be used to make requests to the +// coordination server. It may return a cached connection or create a new one. +// Dials are singleflighted, so concurrent calls to getConn may only dial once. +// As such, context values may not be respected as there are no guarantees that +// the context passed to getConn is the same as the context passed to dial. func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) { nc.mu.Lock() if last := nc.last; last != nil && last.canTakeNewRequest() { @@ -295,11 +315,35 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) { } nc.mu.Unlock() - conn, err, _ := nc.sfDial.Do(struct{}{}, nc.dial) - if err != nil { - return nil, err + for { + // We singeflight the dial to avoid making multiple connections, however + // that means that we can't simply cancel the dial if the context is + // canceled. 
Instead, we have to additionally check that the context + // which was canceled is our context and retry if our context is still + // valid. + conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseConn, error) { + c, err := nc.dial(ctx) + if err != nil { + if ctx.Err() != nil { + return nil, contextErr{ctx.Err()} + } + return nil, err + } + return c, nil + }) + var ce contextErr + if err == nil || !errors.As(err, &ce) { + return conn, err + } + if ctx.Err() == nil { + // The dial failed because of a context error, but our context + // is still valid. Retry. + continue + } + // The dial failed because our context was canceled. Return the + // underlying error. + return nil, ce.Unwrap() } - return conn, nil } func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) { @@ -344,7 +388,7 @@ func (nc *NoiseClient) Close() error { // dial opens a new connection to tailcontrol, fetching the server noise key // if not cached. -func (nc *NoiseClient) dial() (*noiseConn, error) { +func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) { nc.mu.Lock() connID := nc.nextID nc.nextID++ @@ -392,7 +436,7 @@ func (nc *NoiseClient) dial() (*noiseConn, error) { } timeout := time.Duration(timeoutSec * float64(time.Second)) - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() clientConn, err := (&controlhttp.Dialer{ @@ -407,6 +451,7 @@ func (nc *NoiseClient) dial() (*noiseConn, error) { DialPlan: dialPlan, Logf: nc.logf, NetMon: nc.netMon, + Clock: tstime.StdClock{}, }).Dial(ctx) if err != nil { return nil, err diff --git a/vendor/tailscale.com/control/controlclient/sign_supported.go b/vendor/tailscale.com/control/controlclient/sign_supported.go index c7525b7a7d..2dc8efa1ee 100644 --- a/vendor/tailscale.com/control/controlclient/sign_supported.go +++ b/vendor/tailscale.com/control/controlclient/sign_supported.go @@ -127,7 +127,7 @@ func findIdentity(subject string, st 
certstore.Store) (certstore.Identity, []*x5 return nil, nil, err } - selected, chain := selectIdentityFromSlice(subject, ids, time.Now()) + selected, chain := selectIdentityFromSlice(subject, ids, clock.Now()) for _, id := range ids { if id != selected { diff --git a/vendor/tailscale.com/control/controlclient/status.go b/vendor/tailscale.com/control/controlclient/status.go index 294c5dbb2f..d0fdf80d74 100644 --- a/vendor/tailscale.com/control/controlclient/status.go +++ b/vendor/tailscale.com/control/controlclient/status.go @@ -8,7 +8,6 @@ import ( "fmt" "reflect" - "tailscale.com/types/empty" "tailscale.com/types/netmap" "tailscale.com/types/persist" "tailscale.com/types/structs" @@ -38,6 +37,10 @@ const ( StateSynchronized // connected and received map update ) +func (s State) AppendText(b []byte) ([]byte, error) { + return append(b, s.String()...), nil +} + func (s State) MarshalText() ([]byte, error) { return []byte(s.String()), nil } @@ -62,34 +65,55 @@ func (s State) String() string { } type Status struct { - _ structs.Incomparable - LoginFinished *empty.Message // nonempty when login finishes - LogoutFinished *empty.Message // nonempty when logout finishes - Err error - URL string // interactive URL to visit to finish logging in - NetMap *netmap.NetworkMap // server-pushed configuration - - // The internal state should not be exposed outside this + _ structs.Incomparable + + // Err, if non-nil, is an error that occurred while logging in. + // + // If it's of type UserVisibleError then it's meant to be shown to users in + // their Tailscale client. Otherwise it's just logged to tailscaled's logs. + Err error + + // URL, if non-empty, is the interactive URL to visit to finish logging in. + URL string + + // NetMap is the latest server-pushed state of the tailnet network. + NetMap *netmap.NetworkMap + + // Persist, when Valid, is the locally persisted configuration. + // + // TODO(bradfitz,maisem): clarify this. 
+ Persist persist.PersistView + + // state is the internal state. It should not be exposed outside this // package, but we have some automated tests elsewhere that need to - // use them. Please don't use these fields. + // use it via the StateForTest accessor. // TODO(apenwarr): Unexport or remove these. - State State - Persist *persist.PersistView // locally persisted configuration + state State } +// LoginFinished reports whether the controlclient is in its "StateAuthenticated" +// state where it's in a happy register state but not yet in a map poll. +// +// TODO(bradfitz): delete this and everything around Status.state. +func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } + +// StateForTest returns the internal state of s for tests only. +func (s *Status) StateForTest() State { return s.state } + +// SetStateForTest sets the internal state of s for tests only. +func (s *Status) SetStateForTest(state State) { s.state = state } + // Equal reports whether s and s2 are equal. 
func (s *Status) Equal(s2 *Status) bool { if s == nil && s2 == nil { return true } return s != nil && s2 != nil && - (s.LoginFinished == nil) == (s2.LoginFinished == nil) && - (s.LogoutFinished == nil) == (s2.LogoutFinished == nil) && s.Err == s2.Err && s.URL == s2.URL && + s.state == s2.state && reflect.DeepEqual(s.Persist, s2.Persist) && - reflect.DeepEqual(s.NetMap, s2.NetMap) && - s.State == s2.State + reflect.DeepEqual(s.NetMap, s2.NetMap) } func (s Status) String() string { @@ -97,5 +121,5 @@ func (s Status) String() string { if err != nil { panic(err) } - return s.State.String() + " " + string(b) + return s.state.String() + " " + string(b) } diff --git a/vendor/tailscale.com/control/controlhttp/client.go b/vendor/tailscale.com/control/controlhttp/client.go index b0d91bada8..fb220fd0b0 100644 --- a/vendor/tailscale.com/control/controlhttp/client.go +++ b/vendor/tailscale.com/control/controlhttp/client.go @@ -45,6 +45,7 @@ import ( "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/util/multierr" ) @@ -147,13 +148,16 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { // before we do anything. if c.DialStartDelaySec > 0 { a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP) - tmr := time.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second))) + if a.Clock == nil { + a.Clock = tstime.StdClock{} + } + tmr, tmrChannel := a.Clock.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second))) defer tmr.Stop() select { case <-ctx.Done(): err = ctx.Err() return - case <-tmr.C: + case <-tmrChannel: } } @@ -319,7 +323,10 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*ClientConn, er // In case outbound port 80 blocked or MITM'ed poorly, start a backup timer // to dial port 443 if port 80 doesn't either succeed or fail quickly. 
- try443Timer := time.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) }) + if a.Clock == nil { + a.Clock = tstime.StdClock{} + } + try443Timer := a.Clock.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) }) defer try443Timer.Stop() var err80, err443 error diff --git a/vendor/tailscale.com/control/controlhttp/client_js.go b/vendor/tailscale.com/control/controlhttp/client_js.go index 5a4b4d08b1..7ad5963660 100644 --- a/vendor/tailscale.com/control/controlhttp/client_js.go +++ b/vendor/tailscale.com/control/controlhttp/client_js.go @@ -51,7 +51,7 @@ func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) { if err != nil { return nil, err } - netConn := wsconn.NetConn(context.Background(), wsConn, websocket.MessageBinary) + netConn := wsconn.NetConn(context.Background(), wsConn, websocket.MessageBinary, wsURL.String()) cbConn, err := cont(ctx, netConn) if err != nil { netConn.Close() diff --git a/vendor/tailscale.com/control/controlhttp/constants.go b/vendor/tailscale.com/control/controlhttp/constants.go index b838f84c43..72161336e3 100644 --- a/vendor/tailscale.com/control/controlhttp/constants.go +++ b/vendor/tailscale.com/control/controlhttp/constants.go @@ -11,6 +11,7 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -89,6 +90,10 @@ type Dialer struct { drainFinished chan struct{} omitCertErrorLogging bool testFallbackDelay time.Duration + + // tstime.Clock is used instead of time package for methods such as time.Now. + // If not specified, will default to tstime.StdClock{}. 
+ Clock tstime.Clock } func strDef(v1, v2 string) string { diff --git a/vendor/tailscale.com/control/controlhttp/server.go b/vendor/tailscale.com/control/controlhttp/server.go index d49e32c1da..ee469fabda 100644 --- a/vendor/tailscale.com/control/controlhttp/server.go +++ b/vendor/tailscale.com/control/controlhttp/server.go @@ -146,7 +146,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request return nil, fmt.Errorf("decoding base64 handshake parameter: %v", err) } - conn := wsconn.NetConn(ctx, c, websocket.MessageBinary) + conn := wsconn.NetConn(ctx, c, websocket.MessageBinary, r.RemoteAddr) nc, err := controlbase.Server(ctx, conn, private, init) if err != nil { conn.Close() diff --git a/vendor/tailscale.com/control/controlknobs/controlknobs.go b/vendor/tailscale.com/control/controlknobs/controlknobs.go index 65492b39e0..3ea0575a57 100644 --- a/vendor/tailscale.com/control/controlknobs/controlknobs.go +++ b/vendor/tailscale.com/control/controlknobs/controlknobs.go @@ -6,24 +6,101 @@ package controlknobs import ( + "slices" "sync/atomic" - "tailscale.com/envknob" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/types/opt" ) -// disableUPnP indicates whether to attempt UPnP mapping. -var disableUPnPControl atomic.Bool +// Knobs is the set of knobs that the control plane's coordination server can +// adjust at runtime. +type Knobs struct { + // DisableUPnP indicates whether to attempt UPnP mapping. + DisableUPnP atomic.Bool -var disableUPnpEnv = envknob.RegisterBool("TS_DISABLE_UPNP") + // DisableDRPO is whether control says to disable the + // DERP route optimization (Issue 150). + DisableDRPO atomic.Bool -// DisableUPnP reports the last reported value from control -// whether UPnP portmapping should be disabled. 
-func DisableUPnP() bool { - return disableUPnPControl.Load() || disableUPnpEnv() + // KeepFullWGConfig is whether we should disable the lazy wireguard + // programming and instead give WireGuard the full netmap always, even for + // idle peers. + KeepFullWGConfig atomic.Bool + + // RandomizeClientPort is whether control says we should randomize + // the client port. + RandomizeClientPort atomic.Bool + + // OneCGNAT is whether the the node should make one big CGNAT route + // in the OS rather than one /32 per peer. + OneCGNAT syncs.AtomicValue[opt.Bool] + + // ForceBackgroundSTUN forces netcheck STUN queries to keep + // running in magicsock, even when idle. + ForceBackgroundSTUN atomic.Bool + + // DisableDeltaUpdates is whether the node should not process + // incremental (delta) netmap updates and should treat all netmap + // changes as "full" ones as tailscaled did in 1.48.x and earlier. + DisableDeltaUpdates atomic.Bool + + // PeerMTUEnable is whether the node should do peer path MTU discovery. + PeerMTUEnable atomic.Bool +} + +// UpdateFromNodeAttributes updates k (if non-nil) based on the provided self +// node attributes (Node.Capabilities). 
+func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability, capMap tailcfg.NodeCapMap) { + if k == nil { + return + } + has := func(attr tailcfg.NodeCapability) bool { + _, ok := capMap[attr] + return ok || slices.Contains(selfNodeAttrs, attr) + } + var ( + keepFullWG = has(tailcfg.NodeAttrDebugDisableWGTrim) + disableDRPO = has(tailcfg.NodeAttrDebugDisableDRPO) + disableUPnP = has(tailcfg.NodeAttrDisableUPnP) + randomizeClientPort = has(tailcfg.NodeAttrRandomizeClientPort) + disableDeltaUpdates = has(tailcfg.NodeAttrDisableDeltaUpdates) + oneCGNAT opt.Bool + forceBackgroundSTUN = has(tailcfg.NodeAttrDebugForceBackgroundSTUN) + peerMTUEnable = has(tailcfg.NodeAttrPeerMTUEnable) + ) + + if has(tailcfg.NodeAttrOneCGNATEnable) { + oneCGNAT.Set(true) + } else if has(tailcfg.NodeAttrOneCGNATDisable) { + oneCGNAT.Set(false) + } + + k.KeepFullWGConfig.Store(keepFullWG) + k.DisableDRPO.Store(disableDRPO) + k.DisableUPnP.Store(disableUPnP) + k.RandomizeClientPort.Store(randomizeClientPort) + k.OneCGNAT.Store(oneCGNAT) + k.ForceBackgroundSTUN.Store(forceBackgroundSTUN) + k.DisableDeltaUpdates.Store(disableDeltaUpdates) + k.PeerMTUEnable.Store(peerMTUEnable) } -// SetDisableUPnP sets whether control says that UPnP should be -// disabled. -func SetDisableUPnP(v bool) { - disableUPnPControl.Store(v) +// AsDebugJSON returns k as something that can be marshalled with json.Marshal +// for debug. 
+func (k *Knobs) AsDebugJSON() map[string]any { + if k == nil { + return nil + } + return map[string]any{ + "DisableUPnP": k.DisableUPnP.Load(), + "DisableDRPO": k.DisableDRPO.Load(), + "KeepFullWGConfig": k.KeepFullWGConfig.Load(), + "RandomizeClientPort": k.RandomizeClientPort.Load(), + "OneCGNAT": k.OneCGNAT.Load(), + "ForceBackgroundSTUN": k.ForceBackgroundSTUN.Load(), + "DisableDeltaUpdates": k.DisableDeltaUpdates.Load(), + "PeerMTUEnable": k.PeerMTUEnable.Load(), + } } diff --git a/vendor/tailscale.com/derp/derp.go b/vendor/tailscale.com/derp/derp.go index c1ae5b5937..63af44585e 100644 --- a/vendor/tailscale.com/derp/derp.go +++ b/vendor/tailscale.com/derp/derp.go @@ -85,7 +85,7 @@ const ( // framePeerPresent is like framePeerGone, but for other // members of the DERP region when they're meshed up together. - framePeerPresent = frameType(0x09) // 32B pub key of peer that's connected + framePeerPresent = frameType(0x09) // 32B pub key of peer that's connected + optional 18B ip:port (16 byte IP + 2 byte BE uint16 port) // frameWatchConns is how one DERP node in a regional mesh // subscribes to the others in the region. 
@@ -199,7 +199,7 @@ func readFrame(br *bufio.Reader, maxSize uint32, b []byte) (t frameType, frameLe return 0, 0, fmt.Errorf("frame header size %d exceeds reader limit of %d", frameLen, maxSize) } - n, err := io.ReadFull(br, b[:minUint32(frameLen, uint32(len(b)))]) + n, err := io.ReadFull(br, b[:min(frameLen, uint32(len(b)))]) if err != nil { return 0, 0, err } @@ -233,10 +233,3 @@ func writeFrame(bw *bufio.Writer, t frameType, b []byte) error { } return bw.Flush() } - -func minUint32(a, b uint32) uint32 { - if a < b { - return a - } - return b -} diff --git a/vendor/tailscale.com/derp/derp_client.go b/vendor/tailscale.com/derp/derp_client.go index 2889d81abf..7ad98cfe81 100644 --- a/vendor/tailscale.com/derp/derp_client.go +++ b/vendor/tailscale.com/derp/derp_client.go @@ -17,6 +17,7 @@ import ( "go4.org/mem" "golang.org/x/time/rate" "tailscale.com/syncs" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -40,6 +41,8 @@ type Client struct { // Owned by Recv: peeked int // bytes to discard on next Recv readErr syncs.AtomicValue[error] // sticky (set by Recv) + + clock tstime.Clock } // ClientOpt is an option passed to NewClient. 
@@ -103,6 +106,7 @@ func newClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf meshKey: opt.MeshKey, canAckPings: opt.CanAckPings, isProber: opt.IsProber, + clock: tstime.StdClock{}, } if opt.ServerPub.IsZero() { if err := c.recvServerKey(); err != nil { @@ -214,7 +218,7 @@ func (c *Client) send(dstKey key.NodePublic, pkt []byte) (ret error) { defer c.wmu.Unlock() if c.rate != nil { pktLen := frameHeaderLen + key.NodePublicRawLen + len(pkt) - if !c.rate.AllowN(time.Now(), pktLen) { + if !c.rate.AllowN(c.clock.Now(), pktLen) { return nil // drop } } @@ -244,7 +248,7 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err e c.wmu.Lock() defer c.wmu.Unlock() - timer := time.AfterFunc(5*time.Second, c.writeTimeoutFired) + timer := c.clock.AfterFunc(5*time.Second, c.writeTimeoutFired) defer timer.Stop() if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil { @@ -359,7 +363,12 @@ func (PeerGoneMessage) msg() {} // PeerPresentMessage is a ReceivedMessage that indicates that the client // is connected to the server. (Only used by trusted mesh clients) -type PeerPresentMessage key.NodePublic +type PeerPresentMessage struct { + // Key is the public key of the client. + Key key.NodePublic + // IPPort is the remote IP and port of the client. 
+ IPPort netip.AddrPort +} func (PeerPresentMessage) msg() {} @@ -457,7 +466,6 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro c.readErr.Store(err) } }() - for { c.nc.SetReadDeadline(time.Now().Add(timeout)) @@ -543,8 +551,15 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro c.logf("[unexpected] dropping short peerPresent frame from DERP server") continue } - pg := PeerPresentMessage(key.NodePublicFromRaw32(mem.B(b[:keyLen]))) - return pg, nil + var msg PeerPresentMessage + msg.Key = key.NodePublicFromRaw32(mem.B(b[:keyLen])) + if n >= keyLen+16+2 { + msg.IPPort = netip.AddrPortFrom( + netip.AddrFrom16([16]byte(b[keyLen:keyLen+16])).Unmap(), + binary.BigEndian.Uint16(b[keyLen+16:keyLen+16+2]), + ) + } + return msg, nil case frameRecvPacket: var rp ReceivedPacket diff --git a/vendor/tailscale.com/derp/derp_server.go b/vendor/tailscale.com/derp/derp_server.go index 1ad5d25f3e..cf42acdf79 100644 --- a/vendor/tailscale.com/derp/derp_server.go +++ b/vendor/tailscale.com/derp/derp_server.go @@ -12,6 +12,7 @@ import ( crand "crypto/rand" "crypto/x509" "crypto/x509/pkix" + "encoding/binary" "encoding/json" "errors" "expvar" @@ -39,9 +40,11 @@ import ( "tailscale.com/envknob" "tailscale.com/metrics" "tailscale.com/syncs" + "tailscale.com/tstime" "tailscale.com/tstime/rate" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/set" "tailscale.com/version" ) @@ -149,7 +152,7 @@ type Server struct { closed bool netConns map[Conn]chan struct{} // chan is closed when conn closes clients map[key.NodePublic]clientSet - watchers map[*sclient]bool // mesh peer -> true + watchers set.Set[*sclient] // mesh peers // clientsMesh tracks all clients in the cluster, both locally // and to mesh peers. 
If the value is nil, that means the // peer is only local (and thus in the clients Map, but not @@ -164,6 +167,8 @@ type Server struct { // maps from netip.AddrPort to a client's public key keyOfAddr map[netip.AddrPort]key.NodePublic + + clock tstime.Clock } // clientSet represents 1 or more *sclients. @@ -216,8 +221,7 @@ func (s singleClient) ForeachClient(f func(*sclient)) { f(s.c) } // All fields are guarded by Server.mu. type dupClientSet struct { // set is the set of connected clients for sclient.key. - // The values are all true. - set map[*sclient]bool + set set.Set[*sclient] // last is the most recent addition to set, or nil if the most // recent one has since disconnected and nobody else has send @@ -258,7 +262,7 @@ func (s *dupClientSet) removeClient(c *sclient) bool { trim := s.sendHistory[:0] for _, v := range s.sendHistory { - if s.set[v] && (len(trim) == 0 || trim[len(trim)-1] != v) { + if s.set.Contains(v) && (len(trim) == 0 || trim[len(trim)-1] != v) { trim = append(trim, v) } } @@ -313,11 +317,12 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { clientsMesh: map[key.NodePublic]PacketForwarder{}, netConns: map[Conn]chan struct{}{}, memSys0: ms.Sys, - watchers: map[*sclient]bool{}, + watchers: set.Set[*sclient]{}, sentTo: map[key.NodePublic]map[key.NodePublic]int64{}, avgQueueDuration: new(uint64), tcpRtt: metrics.LabelMap{Label: "le"}, keyOfAddr: map[netip.AddrPort]key.NodePublic{}, + clock: tstime.StdClock{}, } s.initMetacert() s.packetsRecvDisco = s.packetsRecvByKind.Get("disco") @@ -467,8 +472,8 @@ func (s *Server) initMetacert() { CommonName: fmt.Sprintf("derpkey%s", s.publicKey.UntypedHexString()), }, // Windows requires NotAfter and NotBefore set: - NotAfter: time.Now().Add(30 * 24 * time.Hour), - NotBefore: time.Now().Add(-30 * 24 * time.Hour), + NotAfter: s.clock.Now().Add(30 * 24 * time.Hour), + NotBefore: s.clock.Now().Add(-30 * 24 * time.Hour), // Per 
https://github.com/golang/go/issues/51759#issuecomment-1071147836, // macOS requires BasicConstraints when subject == issuer: BasicConstraintsValid: true, @@ -494,8 +499,8 @@ func (s *Server) registerClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - set := s.clients[c.key] - switch set := set.(type) { + curSet := s.clients[c.key] + switch curSet := curSet.(type) { case nil: s.clients[c.key] = singleClient{c} c.debugLogf("register single client") @@ -503,14 +508,14 @@ func (s *Server) registerClient(c *sclient) { s.dupClientKeys.Add(1) s.dupClientConns.Add(2) // both old and new count s.dupClientConnTotal.Add(1) - old := set.ActiveClient() + old := curSet.ActiveClient() old.isDup.Store(true) c.isDup.Store(true) s.clients[c.key] = &dupClientSet{ last: c, - set: map[*sclient]bool{ - old: true, - c: true, + set: set.Set[*sclient]{ + old: struct{}{}, + c: struct{}{}, }, sendHistory: []*sclient{old}, } @@ -519,9 +524,9 @@ func (s *Server) registerClient(c *sclient) { s.dupClientConns.Add(1) // the gauge s.dupClientConnTotal.Add(1) // the counter c.isDup.Store(true) - set.set[c] = true - set.last = c - set.sendHistory = append(set.sendHistory, c) + curSet.set.Add(c) + curSet.last = c + curSet.sendHistory = append(curSet.sendHistory, c) c.debugLogf("register another duplicate client") } @@ -530,7 +535,7 @@ func (s *Server) registerClient(c *sclient) { } s.keyOfAddr[c.remoteIPPort] = c.key s.curClients.Add(1) - s.broadcastPeerStateChangeLocked(c.key, true) + s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, true) } // broadcastPeerStateChangeLocked enqueues a message to all watchers @@ -538,9 +543,13 @@ func (s *Server) registerClient(c *sclient) { // presence changed. // // s.mu must be held. 
-func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, present bool) { +func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, ipPort netip.AddrPort, present bool) { for w := range s.watchers { - w.peerStateChange = append(w.peerStateChange, peerConnState{peer: peer, present: present}) + w.peerStateChange = append(w.peerStateChange, peerConnState{ + peer: peer, + present: present, + ipPort: ipPort, + }) go w.requestMeshUpdate() } } @@ -561,7 +570,7 @@ func (s *Server) unregisterClient(c *sclient) { delete(s.clientsMesh, c.key) s.notePeerGoneFromRegionLocked(c.key) } - s.broadcastPeerStateChangeLocked(c.key, false) + s.broadcastPeerStateChangeLocked(c.key, netip.AddrPort{}, false) case *dupClientSet: c.debugLogf("removed duplicate client") if set.removeClient(c) { @@ -651,13 +660,21 @@ func (s *Server) addWatcher(c *sclient) { defer s.mu.Unlock() // Queue messages for each already-connected client. - for peer := range s.clients { - c.peerStateChange = append(c.peerStateChange, peerConnState{peer: peer, present: true}) + for peer, clientSet := range s.clients { + ac := clientSet.ActiveClient() + if ac == nil { + continue + } + c.peerStateChange = append(c.peerStateChange, peerConnState{ + peer: peer, + present: true, + ipPort: ac.remoteIPPort, + }) } // And enroll the watcher in future updates (of both // connections & disconnections). 
- s.watchers[c] = true + s.watchers.Add(c) go c.requestMeshUpdate() } @@ -697,7 +714,7 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem done: ctx.Done(), remoteAddr: remoteAddr, remoteIPPort: remoteIPPort, - connectedAt: time.Now(), + connectedAt: s.clock.Now(), sendQueue: make(chan pkt, perClientSendQueueDepth), discoSendQueue: make(chan pkt, perClientSendQueueDepth), sendPongCh: make(chan [8]byte, 1), @@ -927,7 +944,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { return c.sendPkt(dst, pkt{ bs: contents, - enqueuedAt: time.Now(), + enqueuedAt: c.s.clock.Now(), src: srcKey, }) } @@ -994,7 +1011,7 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { p := pkt{ bs: contents, - enqueuedAt: time.Now(), + enqueuedAt: c.s.clock.Now(), src: c.key, } return c.sendPkt(dst, p) @@ -1345,6 +1362,7 @@ type sclient struct { type peerConnState struct { peer key.NodePublic present bool + ipPort netip.AddrPort // if present, the peer's IP:port } // pkt is a request to write a data frame to an sclient. @@ -1387,7 +1405,7 @@ func (c *sclient) setPreferred(v bool) { // graphs, so not important to miss a move. But it shouldn't: // the netcheck/re-STUNs in magicsock only happen about every // 30 seconds. - if time.Since(c.connectedAt) > 5*time.Second { + if c.s.clock.Since(c.connectedAt) > 5*time.Second { homeMove.Add(1) } } @@ -1401,7 +1419,7 @@ func expMovingAverage(prev, newValue, alpha float64) float64 { // recordQueueTime updates the average queue duration metric after a packet has been sent. 
func (c *sclient) recordQueueTime(enqueuedAt time.Time) { - elapsed := float64(time.Since(enqueuedAt).Milliseconds()) + elapsed := float64(c.s.clock.Since(enqueuedAt).Milliseconds()) for { old := atomic.LoadUint64(c.s.avgQueueDuration) newAvg := expMovingAverage(math.Float64frombits(old), elapsed, 0.1) @@ -1431,7 +1449,7 @@ func (c *sclient) sendLoop(ctx context.Context) error { }() jitter := time.Duration(rand.Intn(5000)) * time.Millisecond - keepAliveTick := time.NewTicker(keepAlive + jitter) + keepAliveTick, keepAliveTickChannel := c.s.clock.NewTicker(keepAlive + jitter) defer keepAliveTick.Stop() var werr error // last write error @@ -1461,7 +1479,7 @@ func (c *sclient) sendLoop(ctx context.Context) error { case msg := <-c.sendPongCh: werr = c.sendPong(msg) continue - case <-keepAliveTick.C: + case <-keepAliveTickChannel: werr = c.sendKeepAlive() continue default: @@ -1490,7 +1508,7 @@ func (c *sclient) sendLoop(ctx context.Context) error { case msg := <-c.sendPongCh: werr = c.sendPong(msg) continue - case <-keepAliveTick.C: + case <-keepAliveTickChannel: werr = c.sendKeepAlive() } } @@ -1538,12 +1556,18 @@ func (c *sclient) sendPeerGone(peer key.NodePublic, reason PeerGoneReasonType) e } // sendPeerPresent sends a peerPresent frame, without flushing. 
-func (c *sclient) sendPeerPresent(peer key.NodePublic) error { +func (c *sclient) sendPeerPresent(peer key.NodePublic, ipPort netip.AddrPort) error { c.setWriteDeadline() - if err := writeFrameHeader(c.bw.bw(), framePeerPresent, keyLen); err != nil { + const frameLen = keyLen + 16 + 2 + if err := writeFrameHeader(c.bw.bw(), framePeerPresent, frameLen); err != nil { return err } - _, err := c.bw.Write(peer.AppendTo(nil)) + payload := make([]byte, frameLen) + _ = peer.AppendTo(payload[:0]) + a16 := ipPort.Addr().As16() + copy(payload[keyLen:], a16[:]) + binary.BigEndian.PutUint16(payload[keyLen+16:], ipPort.Port()) + _, err := c.bw.Write(payload) return err } @@ -1562,7 +1586,7 @@ func (c *sclient) sendMeshUpdates() error { } var err error if pcs.present { - err = c.sendPeerPresent(pcs.peer) + err = c.sendPeerPresent(pcs.peer, pcs.ipPort) } else { err = c.sendPeerGone(pcs.peer, PeerGoneReasonDisconnected) } diff --git a/vendor/tailscale.com/derp/derp_server_linux.go b/vendor/tailscale.com/derp/derp_server_linux.go index 1029226a6f..48da8ed30a 100644 --- a/vendor/tailscale.com/derp/derp_server_linux.go +++ b/vendor/tailscale.com/derp/derp_server_linux.go @@ -9,45 +9,37 @@ import ( "net" "time" - "golang.org/x/sys/unix" + "tailscale.com/net/tcpinfo" ) func (c *sclient) statsLoop(ctx context.Context) error { - // If we can't get a TCP socket, then we can't send stats. - tcpConn := c.tcpConn() - if tcpConn == nil { + // Get the RTT initially to verify it's supported. 
+ conn := c.tcpConn() + if conn == nil { c.s.tcpRtt.Add("non-tcp", 1) return nil } - rawConn, err := tcpConn.SyscallConn() - if err != nil { - c.logf("error getting SyscallConn: %v", err) + if _, err := tcpinfo.RTT(conn); err != nil { + c.logf("error fetching initial RTT: %v", err) c.s.tcpRtt.Add("error", 1) return nil } const statsInterval = 10 * time.Second - ticker := time.NewTicker(statsInterval) + ticker, tickerChannel := c.s.clock.NewTicker(statsInterval) defer ticker.Stop() - var ( - tcpInfo *unix.TCPInfo - sysErr error - ) statsLoop: for { select { - case <-ticker.C: - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) - }) - if err != nil || sysErr != nil { + case <-tickerChannel: + rtt, err := tcpinfo.RTT(conn) + if err != nil { continue statsLoop } // TODO(andrew): more metrics? - rtt := time.Duration(tcpInfo.Rtt) * time.Microsecond c.s.tcpRtt.Add(durationToLabel(rtt), 1) case <-ctx.Done(): diff --git a/vendor/tailscale.com/derp/derphttp/derphttp_client.go b/vendor/tailscale.com/derp/derphttp/derphttp_client.go index 07317fcbf3..3bd3144648 100644 --- a/vendor/tailscale.com/derp/derphttp/derphttp_client.go +++ b/vendor/tailscale.com/derp/derphttp/derphttp_client.go @@ -38,6 +38,7 @@ import ( "tailscale.com/net/tshttpproxy" "tailscale.com/syncs" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/cmpx" @@ -55,6 +56,11 @@ type Client struct { MeshKey string // optional; for trusted clients IsProber bool // optional; for probers to optional declare themselves as such + // BaseContext, if non-nil, returns the base context to use for dialing a + // new derp server. If nil, context.Background is used. + // In either case, additional timeouts may be added to the base context. 
+ BaseContext func() context.Context + privateKey key.NodePrivate logf logger.Logf netMon *netmon.Monitor // optional; nil means interfaces will be looked up on-demand @@ -83,6 +89,7 @@ type Client struct { serverPubKey key.NodePublic tlsState *tls.ConnectionState pingOut map[derp.PingMessage]chan<- bool // chan to send to on pong + clock tstime.Clock } func (c *Client) String() string { @@ -101,6 +108,7 @@ func NewRegionClient(privateKey key.NodePrivate, logf logger.Logf, netMon *netmo getRegion: getRegion, ctx: ctx, cancelCtx: cancel, + clock: tstime.StdClock{}, } return c } @@ -108,7 +116,7 @@ func NewRegionClient(privateKey key.NodePrivate, logf logger.Logf, netMon *netmo // NewNetcheckClient returns a Client that's only able to have its DialRegionTLS method called. // It's used by the netcheck package. func NewNetcheckClient(logf logger.Logf) *Client { - return &Client{logf: logf} + return &Client{logf: logf, clock: tstime.StdClock{}} } // NewClient returns a new DERP-over-HTTP client. It connects lazily. @@ -129,6 +137,7 @@ func NewClient(privateKey key.NodePrivate, serverURL string, logf logger.Logf) ( url: u, ctx: ctx, cancelCtx: cancel, + clock: tstime.StdClock{}, } return c, nil } @@ -140,6 +149,19 @@ func (c *Client) Connect(ctx context.Context) error { return err } +// newContext returns a new context for setting up a new DERP connection. +// It uses either c.BaseContext or returns context.Background. +func (c *Client) newContext() context.Context { + if c.BaseContext != nil { + ctx := c.BaseContext() + if ctx == nil { + panic("BaseContext returned nil") + } + return ctx + } + return context.Background() +} + // TLSConnectionState returns the last TLS connection state, if any. // The client must already be connected. 
func (c *Client) TLSConnectionState() (_ *tls.ConnectionState, ok bool) { @@ -644,14 +666,14 @@ func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, e nwait++ go func() { if proto == "tcp4" && c.preferIPv6() { - t := time.NewTimer(200 * time.Millisecond) + t, tChannel := c.clock.NewTimer(200 * time.Millisecond) select { case <-ctx.Done(): // Either user canceled original context, // it timed out, or the v6 dial succeeded. t.Stop() return - case <-t.C: + case <-tChannel: // Start v4 dial } } @@ -708,8 +730,9 @@ func firstStr(a, b string) string { } // dialNodeUsingProxy connects to n using a CONNECT to the HTTP(s) proxy in proxyURL. -func (c *Client) dialNodeUsingProxy(ctx context.Context, n *tailcfg.DERPNode, proxyURL *url.URL) (proxyConn net.Conn, err error) { +func (c *Client) dialNodeUsingProxy(ctx context.Context, n *tailcfg.DERPNode, proxyURL *url.URL) (_ net.Conn, err error) { pu := proxyURL + var proxyConn net.Conn if pu.Scheme == "https" { var d tls.Dialer proxyConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(pu.Hostname(), firstStr(pu.Port(), "443"))) @@ -772,7 +795,7 @@ func (c *Client) dialNodeUsingProxy(ctx context.Context, n *tailcfg.DERPNode, pr } func (c *Client) Send(dstKey key.NodePublic, b []byte) error { - client, _, err := c.connect(context.TODO(), "derphttp.Client.Send") + client, _, err := c.connect(c.newContext(), "derphttp.Client.Send") if err != nil { return err } @@ -872,7 +895,7 @@ func (c *Client) LocalAddr() (netip.AddrPort, error) { } func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error { - client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket") + client, _, err := c.connect(c.newContext(), "derphttp.Client.ForwardPacket") if err != nil { return err } @@ -938,7 +961,7 @@ func (c *Client) NotePreferred(v bool) { // // Only trusted connections (using MeshKey) are allowed to use this. 
func (c *Client) WatchConnectionChanges() error { - client, _, err := c.connect(context.TODO(), "derphttp.Client.WatchConnectionChanges") + client, _, err := c.connect(c.newContext(), "derphttp.Client.WatchConnectionChanges") if err != nil { return err } @@ -953,7 +976,7 @@ func (c *Client) WatchConnectionChanges() error { // // Only trusted connections (using MeshKey) are allowed to use this. func (c *Client) ClosePeer(target key.NodePublic) error { - client, _, err := c.connect(context.TODO(), "derphttp.Client.ClosePeer") + client, _, err := c.connect(c.newContext(), "derphttp.Client.ClosePeer") if err != nil { return err } @@ -974,7 +997,7 @@ func (c *Client) Recv() (derp.ReceivedMessage, error) { // RecvDetail is like Recv, but additional returns the connection generation on each message. // The connGen value is incremented every time the derphttp.Client reconnects to the server. func (c *Client) RecvDetail() (m derp.ReceivedMessage, connGen int, err error) { - client, connGen, err := c.connect(context.TODO(), "derphttp.Client.Recv") + client, connGen, err := c.connect(c.newContext(), "derphttp.Client.Recv") if err != nil { return nil, 0, err } diff --git a/vendor/tailscale.com/derp/derphttp/mesh_client.go b/vendor/tailscale.com/derp/derphttp/mesh_client.go index 4454136ab1..748598d6fe 100644 --- a/vendor/tailscale.com/derp/derphttp/mesh_client.go +++ b/vendor/tailscale.com/derp/derphttp/mesh_client.go @@ -5,6 +5,7 @@ package derphttp import ( "context" + "net/netip" "sync" "time" @@ -26,7 +27,7 @@ import ( // // To force RunWatchConnectionLoop to return quickly, its ctx needs to // be closed, and c itself needs to be closed. 
-func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add, remove func(key.NodePublic)) { +func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add func(key.NodePublic, netip.AddrPort), remove func(key.NodePublic)) { if infoLogf == nil { infoLogf = logger.Discard } @@ -51,7 +52,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key present = map[key.NodePublic]bool{} } lastConnGen := 0 - lastStatus := time.Now() + lastStatus := c.clock.Now() logConnectedLocked := func() { if loggedConnected { return @@ -61,16 +62,16 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key } const logConnectedDelay = 200 * time.Millisecond - timer := time.AfterFunc(2*time.Second, func() { + timer := c.clock.AfterFunc(2*time.Second, func() { mu.Lock() defer mu.Unlock() logConnectedLocked() }) defer timer.Stop() - updatePeer := func(k key.NodePublic, isPresent bool) { + updatePeer := func(k key.NodePublic, ipPort netip.AddrPort, isPresent bool) { if isPresent { - add(k) + add(k, ipPort) } else { remove(k) } @@ -91,11 +92,11 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key } sleep := func(d time.Duration) { - t := time.NewTimer(d) + t, tChannel := c.clock.NewTimer(d) select { case <-ctx.Done(): t.Stop() - case <-t.C: + case <-tChannel: } } @@ -126,7 +127,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key } switch m := m.(type) { case derp.PeerPresentMessage: - updatePeer(key.NodePublic(m), true) + updatePeer(m.Key, m.IPPort, true) case derp.PeerGoneMessage: switch m.Reason { case derp.PeerGoneReasonDisconnected: @@ -138,11 +139,11 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key logf("Recv: peer %s not at server %s for unknown reason %v", key.NodePublic(m.Peer).ShortString(), 
c.ServerPublicKey().ShortString(), m.Reason) } - updatePeer(key.NodePublic(m.Peer), false) + updatePeer(key.NodePublic(m.Peer), netip.AddrPort{}, false) default: continue } - if now := time.Now(); now.Sub(lastStatus) > statusInterval { + if now := c.clock.Now(); now.Sub(lastStatus) > statusInterval { lastStatus = now infoLogf("%d peers", len(present)) } diff --git a/vendor/tailscale.com/derp/derphttp/websocket.go b/vendor/tailscale.com/derp/derphttp/websocket.go index 730f975ff7..08e4018540 100644 --- a/vendor/tailscale.com/derp/derphttp/websocket.go +++ b/vendor/tailscale.com/derp/derphttp/websocket.go @@ -27,6 +27,6 @@ func dialWebsocket(ctx context.Context, urlStr string) (net.Conn, error) { return nil, err } log.Printf("websocket: connected to %v", urlStr) - netConn := wsconn.NetConn(context.Background(), c, websocket.MessageBinary) + netConn := wsconn.NetConn(context.Background(), c, websocket.MessageBinary, urlStr) return netConn, nil } diff --git a/vendor/tailscale.com/disco/disco.go b/vendor/tailscale.com/disco/disco.go index 0e7c3f7e5f..46379b9d29 100644 --- a/vendor/tailscale.com/disco/disco.go +++ b/vendor/tailscale.com/disco/disco.go @@ -94,6 +94,9 @@ type Message interface { AppendMarshal([]byte) []byte } +// MessageHeaderLen is the length of a message header, 2 bytes for type and version. +const MessageHeaderLen = 2 + // appendMsgHeader appends two bytes (for t and ver) and then also // dataLen bytes to b, returning the appended slice in all. The // returned data slice is a subslice of all with just dataLen bytes of @@ -117,15 +120,24 @@ type Ping struct { // netmap data to reduce the discokey:nodekey relation from 1:N to // 1:1. NodeKey key.NodePublic + + // Padding is the number of 0 bytes at the end of the + // message. (It's used to probe path MTU.) + Padding int } +// PingLen is the length of a marshalled ping message, without the message +// header or padding. 
+const PingLen = 12 + key.NodePublicRawLen + func (m *Ping) AppendMarshal(b []byte) []byte { dataLen := 12 hasKey := !m.NodeKey.IsZero() if hasKey { dataLen += key.NodePublicRawLen } - ret, d := appendMsgHeader(b, TypePing, v0, dataLen) + + ret, d := appendMsgHeader(b, TypePing, v0, dataLen+m.Padding) n := copy(d, m.TxID[:]) if hasKey { m.NodeKey.AppendTo(d[:n]) @@ -138,11 +150,14 @@ func parsePing(ver uint8, p []byte) (m *Ping, err error) { return nil, errShort } m = new(Ping) + m.Padding = len(p) p = p[copy(m.TxID[:], p):] + m.Padding -= 12 // Deliberately lax on longer-than-expected messages, for future // compatibility. if len(p) >= key.NodePublicRawLen { m.NodeKey = key.NodePublicFromRaw32(mem.B(p[:key.NodePublicRawLen])) + m.Padding -= key.NodePublicRawLen } return m, nil } @@ -214,6 +229,8 @@ type Pong struct { Src netip.AddrPort // 18 bytes (16+2) on the wire; v4-mapped ipv6 for IPv4 } +// pongLen is the length of a marshalled pong message, without the message +// header or padding. const pongLen = 12 + 16 + 2 func (m *Pong) AppendMarshal(b []byte) []byte { diff --git a/vendor/tailscale.com/disco/pcap.go b/vendor/tailscale.com/disco/pcap.go new file mode 100644 index 0000000000..7103542486 --- /dev/null +++ b/vendor/tailscale.com/disco/pcap.go @@ -0,0 +1,40 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package disco + +import ( + "bytes" + "encoding/binary" + "net/netip" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// ToPCAPFrame marshals the bytes for a pcap record that describe a disco frame. +// +// Warning: Alloc garbage. Acceptable while capturing. +func ToPCAPFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []byte) []byte { + var ( + b bytes.Buffer + flag uint8 + ) + b.Grow(128) // Most disco frames will probably be smaller than this. 
+ + if src.Addr() == tailcfg.DerpMagicIPAddr { + flag |= 0x01 + } + b.WriteByte(flag) // 1b: flag + + derpSrc := derpNodeSrc.Raw32() + b.Write(derpSrc[:]) // 32b: derp public key + binary.Write(&b, binary.LittleEndian, uint16(src.Port())) // 2b: port + addr, _ := src.Addr().MarshalBinary() + binary.Write(&b, binary.LittleEndian, uint16(len(addr))) // 2b: len(addr) + b.Write(addr) // Xb: addr + binary.Write(&b, binary.LittleEndian, uint16(len(payload))) // 2b: len(payload) + b.Write(payload) // Xb: payload + + return b.Bytes() +} diff --git a/vendor/tailscale.com/envknob/envknob.go b/vendor/tailscale.com/envknob/envknob.go index 654131e2a3..7bc2b6c6d7 100644 --- a/vendor/tailscale.com/envknob/envknob.go +++ b/vendor/tailscale.com/envknob/envknob.go @@ -389,12 +389,24 @@ func CanTaildrop() bool { return !Bool("TS_DISABLE_TAILDROP") } // SSHPolicyFile returns the path, if any, to the SSHPolicy JSON file for development. func SSHPolicyFile() string { return String("TS_DEBUG_SSH_POLICY_FILE") } -// SSHIgnoreTailnetPolicy is whether to ignore the Tailnet SSH policy for development. +// SSHIgnoreTailnetPolicy reports whether to ignore the Tailnet SSH policy for development. func SSHIgnoreTailnetPolicy() bool { return Bool("TS_DEBUG_SSH_IGNORE_TAILNET_POLICY") } -// TKASkipSignatureCheck is whether to skip node-key signature checking for development. +// TKASkipSignatureCheck reports whether to skip node-key signature checking for development. func TKASkipSignatureCheck() bool { return Bool("TS_UNSAFE_SKIP_NKS_VERIFICATION") } +// CrashOnUnexpected reports whether the Tailscale client should panic +// on unexpected conditions. If TS_DEBUG_CRASH_ON_UNEXPECTED is set, that's +// used. Otherwise the default value is true for unstable builds. 
+func CrashOnUnexpected() bool { + if v, ok := crashOnUnexpected().Get(); ok { + return v + } + return version.IsUnstableBuild() +} + +var crashOnUnexpected = RegisterOptBool("TS_DEBUG_CRASH_ON_UNEXPECTED") + // NoLogsNoSupport reports whether the client's opted out of log uploads and // technical support. func NoLogsNoSupport() bool { diff --git a/vendor/tailscale.com/flake.lock b/vendor/tailscale.com/flake.lock index 434501ed46..4e47c40e06 100644 --- a/vendor/tailscale.com/flake.lock +++ b/vendor/tailscale.com/flake.lock @@ -3,11 +3,11 @@ "flake-compat": { "flake": false, "locked": { - "lastModified": 1668681692, - "narHash": "sha256-Ht91NGdewz8IQLtWZ9LCeNXMSXHUss+9COoqu6JLmXU=", + "lastModified": 1673956053, + "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", "owner": "edolstra", "repo": "flake-compat", - "rev": "009399224d5e398d03b22badca40a37ac85412a1", + "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", "type": "github" }, "original": { @@ -17,12 +17,15 @@ } }, "flake-utils": { + "inputs": { + "systems": "systems" + }, "locked": { - "lastModified": 1667395993, - "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", + "lastModified": 1692799911, + "narHash": "sha256-3eihraek4qL744EvQXsK1Ha6C3CR7nnT8X2qWap4RNk=", "owner": "numtide", "repo": "flake-utils", - "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", + "rev": "f9e7cf818399d17d347f847525c5a5a8032e4e44", "type": "github" }, "original": { @@ -33,11 +36,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1675153841, - "narHash": "sha256-EWvU3DLq+4dbJiukfhS7r6sWZyJikgXn6kNl7eHljW8=", + "lastModified": 1693060755, + "narHash": "sha256-KNsbfqewEziFJEpPR0qvVz4rx0x6QXxw1CcunRhlFdk=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "ea692c2ad1afd6384e171eabef4f0887d2b882d3", + "rev": "c66ccfa00c643751da2fd9290e096ceaa30493fc", "type": "github" }, "original": { @@ -53,6 +56,21 @@ "flake-utils": "flake-utils", "nixpkgs": "nixpkgs" } + }, + "systems": { + "locked": { + "lastModified": 
1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/vendor/tailscale.com/flake.nix b/vendor/tailscale.com/flake.nix index c15d8a4672..643aa91ae3 100644 --- a/vendor/tailscale.com/flake.nix +++ b/vendor/tailscale.com/flake.nix @@ -70,7 +70,7 @@ # So really, this flake is for tailscale devs to dogfood with, if # you're an end user you should be prepared for this flake to not # build periodically. - tailscale = pkgs: pkgs.buildGo120Module rec { + tailscale = pkgs: pkgs.buildGo121Module rec { name = "tailscale"; src = ./.; @@ -107,7 +107,7 @@ gotools graphviz perl - go_1_20 + go_1_21 yarn ]; }; @@ -115,4 +115,4 @@ in flake-utils.lib.eachDefaultSystem (system: flakeForSystem nixpkgs system); } -# nix-direnv cache busting line: sha256-fgCrmtJs1svFz0Xn7iwLNrbBNlcO6V0yqGPMY0+V1VQ= +# nix-direnv cache busting line: sha256-aVtlDzC+sbEWlUAzPkAryA/+dqSzoAFc02xikh6yhf8= diff --git a/vendor/tailscale.com/go.mod.sri b/vendor/tailscale.com/go.mod.sri index 5e2b4b15f4..ce7a675fc4 100644 --- a/vendor/tailscale.com/go.mod.sri +++ b/vendor/tailscale.com/go.mod.sri @@ -1 +1 @@ -sha256-fgCrmtJs1svFz0Xn7iwLNrbBNlcO6V0yqGPMY0+V1VQ= +sha256-aVtlDzC+sbEWlUAzPkAryA/+dqSzoAFc02xikh6yhf8= diff --git a/vendor/tailscale.com/go.toolchain.branch b/vendor/tailscale.com/go.toolchain.branch index 0fde260c3d..bbb4d9628f 100644 --- a/vendor/tailscale.com/go.toolchain.branch +++ b/vendor/tailscale.com/go.toolchain.branch @@ -1 +1 @@ -tailscale.go1.20 +tailscale.go1.21 diff --git a/vendor/tailscale.com/go.toolchain.rev b/vendor/tailscale.com/go.toolchain.rev index 55b1e687d3..7a45a9b210 100644 --- a/vendor/tailscale.com/go.toolchain.rev +++ b/vendor/tailscale.com/go.toolchain.rev @@ -1 +1 @@ -40dc4d834a5fde9872bcf470be50069f56c3e3b3 
+2071f43f327a8d544cd2df4b19398ed681e825c7 diff --git a/vendor/tailscale.com/health/health.go b/vendor/tailscale.com/health/health.go index bef416272a..e0881d8102 100644 --- a/vendor/tailscale.com/health/health.go +++ b/vendor/tailscale.com/health/health.go @@ -27,7 +27,7 @@ var ( sysErr = map[Subsystem]error{} // error key => err (or nil for no error) watchers = set.HandleSet[func(Subsystem, error)]{} // opt func to run if error state changes - warnables = map[*Warnable]struct{}{} // set of warnables + warnables = set.Set[*Warnable]{} timer *time.Timer debugHandler = map[string]http.Handler{} @@ -84,7 +84,7 @@ func NewWarnable(opts ...WarnableOpt) *Warnable { } mu.Lock() defer mu.Unlock() - warnables[w] = struct{}{} + warnables.Add(w) return w } @@ -279,27 +279,31 @@ func SetControlHealth(problems []string) { // GotStreamedMapResponse notes that we got a tailcfg.MapResponse // message in streaming mode, even if it's just a keep-alive message. +// +// This also notes that a map poll is in progress. To unset that, call +// SetOutOfPollNetMap(). func GotStreamedMapResponse() { mu.Lock() defer mu.Unlock() lastStreamedMapResponse = time.Now() + if !inMapPoll { + inMapPoll = true + inMapPollSince = time.Now() + } selfCheckLocked() } -// SetInPollNetMap records whether the client has an open -// HTTP long poll open to the control plane. -func SetInPollNetMap(v bool) { +// SetOutOfPollNetMap records that the client is no longer in +// an HTTP map request long poll to the control plane. 
+func SetOutOfPollNetMap() { mu.Lock() defer mu.Unlock() - if v == inMapPoll { + if !inMapPoll { return } - inMapPoll = v - if v { - inMapPollSince = time.Now() - } else { - lastMapPollEndedAt = time.Now() - } + inMapPoll = false + lastMapPollEndedAt = time.Now() + selfCheckLocked() } // GetInPollNetMap reports whether the client has an open diff --git a/vendor/tailscale.com/health/healthmsg/healthmsg.go b/vendor/tailscale.com/health/healthmsg/healthmsg.go index b0064547ad..e2915a1953 100644 --- a/vendor/tailscale.com/health/healthmsg/healthmsg.go +++ b/vendor/tailscale.com/health/healthmsg/healthmsg.go @@ -10,4 +10,5 @@ package healthmsg const ( WarnAcceptRoutesOff = "Some peers are advertising routes but --accept-routes is false" TailscaleSSHOnBut = "Tailscale SSH enabled, but " // + ... something from caller + LockedOut = "this node is locked out; it will not have connectivity until it is signed. For more info, see https://tailscale.com/s/locked-out" ) diff --git a/vendor/tailscale.com/hostinfo/hostinfo.go b/vendor/tailscale.com/hostinfo/hostinfo.go index 2280c6a5ae..0ae72f2956 100644 --- a/vendor/tailscale.com/hostinfo/hostinfo.go +++ b/vendor/tailscale.com/hostinfo/hostinfo.go @@ -141,15 +141,16 @@ func packageTypeCached() string { type EnvType string const ( - KNative = EnvType("kn") - AWSLambda = EnvType("lm") - Heroku = EnvType("hr") - AzureAppService = EnvType("az") - AWSFargate = EnvType("fg") - FlyDotIo = EnvType("fly") - Kubernetes = EnvType("k8s") - DockerDesktop = EnvType("dde") - Replit = EnvType("repl") + KNative = EnvType("kn") + AWSLambda = EnvType("lm") + Heroku = EnvType("hr") + AzureAppService = EnvType("az") + AWSFargate = EnvType("fg") + FlyDotIo = EnvType("fly") + Kubernetes = EnvType("k8s") + DockerDesktop = EnvType("dde") + Replit = EnvType("repl") + HomeAssistantAddOn = EnvType("haao") ) var envType atomic.Value // of EnvType @@ -170,6 +171,7 @@ var ( desktopAtomic atomic.Value // of opt.Bool packagingType atomic.Value // of string 
appType atomic.Value // of string + firewallMode atomic.Value // of string ) // SetPushDeviceToken sets the device token for use in Hostinfo updates. @@ -181,6 +183,9 @@ func SetDeviceModel(model string) { deviceModelAtomic.Store(model) } // SetOSVersion sets the OS version. func SetOSVersion(v string) { osVersionAtomic.Store(v) } +// SetFirewallMode sets the firewall mode for the app. +func SetFirewallMode(v string) { firewallMode.Store(v) } + // SetPackage sets the packaging type for the app. // // As of 2022-03-25, this is used by Android ("nogoogle" for the @@ -202,6 +207,13 @@ func pushDeviceToken() string { return s } +// FirewallMode returns the firewall mode for the app. +// It is empty if unset. +func FirewallMode() string { + s, _ := firewallMode.Load().(string) + return s +} + func desktop() (ret opt.Bool) { if runtime.GOOS != "linux" { return opt.Bool("") @@ -255,6 +267,9 @@ func getEnvType() EnvType { if inReplit() { return Replit } + if inHomeAssistantAddOn() { + return HomeAssistantAddOn + } return "" } @@ -283,7 +298,7 @@ func inContainer() opt.Bool { return nil }) lineread.File("/proc/mounts", func(line []byte) error { - if mem.Contains(mem.B(line), mem.S("fuse.lxcfs")) { + if mem.Contains(mem.B(line), mem.S("lxcfs /proc/cpuinfo fuse.lxcfs")) { ret.Set(true) return io.EOF } @@ -364,6 +379,13 @@ func inDockerDesktop() bool { return false } +func inHomeAssistantAddOn() bool { + if os.Getenv("SUPERVISOR_TOKEN") != "" || os.Getenv("HASSIO_TOKEN") != "" { + return true + } + return false +} + // goArchVar returns the GOARM or GOAMD64 etc value that the binary was built // with. func goArchVar() string { diff --git a/vendor/tailscale.com/ipn/backend.go b/vendor/tailscale.com/ipn/backend.go index 806598aa24..8da7e6a5c8 100644 --- a/vendor/tailscale.com/ipn/backend.go +++ b/vendor/tailscale.com/ipn/backend.go @@ -61,7 +61,7 @@ const ( // each one via RequestEngineStatus. 
NotifyWatchEngineUpdates NotifyWatchOpt = 1 << iota - NotifyInitialState // if set, the first Notify message (sent immediately) will contain the current State + BrowseToURL + NotifyInitialState // if set, the first Notify message (sent immediately) will contain the current State + BrowseToURL + SessionID NotifyInitialPrefs // if set, the first Notify message (sent immediately) will contain the current Prefs NotifyInitialNetMap // if set, the first Notify message (sent immediately) will contain the current NetMap @@ -77,6 +77,12 @@ type Notify struct { _ structs.Incomparable Version string // version number of IPN backend + // SessionID identifies the unique WatchIPNBus session. + // This field is only set in the first message when requesting + // NotifyInitialState. Clients must store it on their side as + // following notifications will not include this field. + SessionID string `json:",omitempty"` + // ErrMessage, if non-nil, contains a critical error message. // For State InUseOtherUser, ErrMessage is not critical and just contains the details. 
ErrMessage *string diff --git a/vendor/tailscale.com/ipn/ipn_clone.go b/vendor/tailscale.com/ipn/ipn_clone.go index 97207d0398..68c942dfbf 100644 --- a/vendor/tailscale.com/ipn/ipn_clone.go +++ b/vendor/tailscale.com/ipn/ipn_clone.go @@ -6,6 +6,7 @@ package ipn import ( + "maps" "net/netip" "tailscale.com/tailcfg" @@ -50,6 +51,7 @@ var _PrefsCloneNeedsRegeneration = Prefs(struct { NetfilterMode preftype.NetfilterMode OperatorUser string ProfileName string + AutoUpdate AutoUpdatePrefs Persist *persist.Persist }{}) @@ -73,10 +75,11 @@ func (src *ServeConfig) Clone() *ServeConfig { dst.Web[k] = v.Clone() } } - if dst.AllowFunnel != nil { - dst.AllowFunnel = map[HostPort]bool{} - for k, v := range src.AllowFunnel { - dst.AllowFunnel[k] = v + dst.AllowFunnel = maps.Clone(src.AllowFunnel) + if dst.Foreground != nil { + dst.Foreground = map[string]*ServeConfig{} + for k, v := range src.Foreground { + dst.Foreground[k] = v.Clone() } } return dst @@ -87,6 +90,8 @@ var _ServeConfigCloneNeedsRegeneration = ServeConfig(struct { TCP map[uint16]*TCPPortHandler Web map[HostPort]*WebServerConfig AllowFunnel map[HostPort]bool + Foreground map[string]*ServeConfig + ETag string }{}) // Clone makes a deep copy of TCPPortHandler. 
diff --git a/vendor/tailscale.com/ipn/ipn_view.go b/vendor/tailscale.com/ipn/ipn_view.go index 1abeb6709d..dbbf374768 100644 --- a/vendor/tailscale.com/ipn/ipn_view.go +++ b/vendor/tailscale.com/ipn/ipn_view.go @@ -79,13 +79,14 @@ func (v PrefsView) Hostname() string { return v.ж.Hostname } func (v PrefsView) NotepadURLs() bool { return v.ж.NotepadURLs } func (v PrefsView) ForceDaemon() bool { return v.ж.ForceDaemon } func (v PrefsView) Egg() bool { return v.ж.Egg } -func (v PrefsView) AdvertiseRoutes() views.IPPrefixSlice { - return views.IPPrefixSliceOf(v.ж.AdvertiseRoutes) +func (v PrefsView) AdvertiseRoutes() views.Slice[netip.Prefix] { + return views.SliceOf(v.ж.AdvertiseRoutes) } func (v PrefsView) NoSNAT() bool { return v.ж.NoSNAT } func (v PrefsView) NetfilterMode() preftype.NetfilterMode { return v.ж.NetfilterMode } func (v PrefsView) OperatorUser() string { return v.ж.OperatorUser } func (v PrefsView) ProfileName() string { return v.ж.ProfileName } +func (v PrefsView) AutoUpdate() AutoUpdatePrefs { return v.ж.AutoUpdate } func (v PrefsView) Persist() persist.PersistView { return v.ж.Persist.View() } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -111,6 +112,7 @@ var _PrefsViewNeedsRegeneration = Prefs(struct { NetfilterMode preftype.NetfilterMode OperatorUser string ProfileName string + AutoUpdate AutoUpdatePrefs Persist *persist.Persist }{}) @@ -175,11 +177,20 @@ func (v ServeConfigView) AllowFunnel() views.Map[HostPort, bool] { return views.MapOf(v.ж.AllowFunnel) } +func (v ServeConfigView) Foreground() views.MapFn[string, *ServeConfig, ServeConfigView] { + return views.MapFnOf(v.ж.Foreground, func(t *ServeConfig) ServeConfigView { + return t.View() + }) +} +func (v ServeConfigView) ETag() string { return v.ж.ETag } + // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
var _ServeConfigViewNeedsRegeneration = ServeConfig(struct { TCP map[uint16]*TCPPortHandler Web map[HostPort]*WebServerConfig AllowFunnel map[HostPort]bool + Foreground map[string]*ServeConfig + ETag string }{}) // View returns a readonly view of TCPPortHandler. diff --git a/vendor/tailscale.com/ipn/ipnlocal/breaktcp_darwin.go b/vendor/tailscale.com/ipn/ipnlocal/breaktcp_darwin.go new file mode 100644 index 0000000000..13566198ce --- /dev/null +++ b/vendor/tailscale.com/ipn/ipnlocal/breaktcp_darwin.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "log" + + "golang.org/x/sys/unix" +) + +func init() { + breakTCPConns = breakTCPConnsDarwin +} + +func breakTCPConnsDarwin() error { + var matched int + for fd := 0; fd < 1000; fd++ { + _, err := unix.GetsockoptTCPConnectionInfo(fd, unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + if err == nil { + matched++ + err = unix.Close(fd) + log.Printf("debug: closed TCP fd %v: %v", fd, err) + } + } + if matched == 0 { + log.Printf("debug: no TCP connections found") + } + return nil +} diff --git a/vendor/tailscale.com/ipn/ipnlocal/breaktcp_linux.go b/vendor/tailscale.com/ipn/ipnlocal/breaktcp_linux.go new file mode 100644 index 0000000000..b82f652124 --- /dev/null +++ b/vendor/tailscale.com/ipn/ipnlocal/breaktcp_linux.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "log" + + "golang.org/x/sys/unix" +) + +func init() { + breakTCPConns = breakTCPConnsLinux +} + +func breakTCPConnsLinux() error { + var matched int + for fd := 0; fd < 1000; fd++ { + _, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) + if err == nil { + matched++ + err = unix.Close(fd) + log.Printf("debug: closed TCP fd %v: %v", fd, err) + } + } + if matched == 0 { + log.Printf("debug: no TCP connections found") + } + return nil +} diff --git 
a/vendor/tailscale.com/ipn/ipnlocal/c2n.go b/vendor/tailscale.com/ipn/ipnlocal/c2n.go index 54428d38e5..cae75df963 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/c2n.go +++ b/vendor/tailscale.com/ipn/ipnlocal/c2n.go @@ -4,6 +4,7 @@ package ipnlocal import ( + "bytes" "encoding/json" "errors" "fmt" @@ -14,17 +15,20 @@ import ( "path/filepath" "runtime" "strconv" + "strings" "time" + "tailscale.com/clientupdate" "tailscale.com/envknob" "tailscale.com/net/sockstats" "tailscale.com/tailcfg" "tailscale.com/util/clientmetric" "tailscale.com/util/goroutines" "tailscale.com/version" - "tailscale.com/version/distro" ) +var c2nLogHeap func(http.ResponseWriter, *http.Request) // non-nil on most platforms (c2n_pprof.go) + func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) { writeJSON := func(v any) { w.Header().Set("Content-Type", "application/json") @@ -36,7 +40,15 @@ func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) w.Write(body) case "/update": - b.handleC2NUpdate(w, r) + switch r.Method { + case http.MethodGet: + b.handleC2NUpdateGet(w, r) + case http.MethodPost: + b.handleC2NUpdatePost(w, r) + default: + http.Error(w, "bad method", http.StatusMethodNotAllowed) + return + } case "/logtail/flush": if r.Method != "POST" { http.Error(w, "bad method", http.StatusMethodNotAllowed) @@ -49,7 +61,7 @@ func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) { } case "/debug/goroutines": w.Header().Set("Content-Type", "text/plain") - w.Write(goroutines.ScrubbedGoroutineDump()) + w.Write(goroutines.ScrubbedGoroutineDump(true)) case "/debug/prefs": writeJSON(b.Prefs()) case "/debug/metrics": @@ -61,7 +73,7 @@ func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) { if secs == 0 { secs -= 1 } - until := time.Now().Add(time.Duration(secs) * time.Second) + until := b.clock.Now().Add(time.Duration(secs) * time.Second) err := b.SetComponentDebugLogging(component, until) var res 
struct { Error string `json:",omitempty"` @@ -70,6 +82,13 @@ func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) { res.Error = err.Error() } writeJSON(res) + case "/debug/logheap": + if c2nLogHeap != nil { + c2nLogHeap(w, r) + } else { + http.Error(w, "not implemented", http.StatusNotImplemented) + return + } case "/ssh/usernames": var req tailcfg.C2NSSHUsernamesRequest if r.Method == "POST" { @@ -102,46 +121,48 @@ func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) { } } -func (b *LocalBackend) handleC2NUpdate(w http.ResponseWriter, r *http.Request) { - // TODO(bradfitz): add some sort of semaphore that prevents two concurrent - // updates, or if one happened in the past 5 minutes, or something. +func (b *LocalBackend) handleC2NUpdateGet(w http.ResponseWriter, r *http.Request) { + b.logf("c2n: GET /update received") - // TODO(bradfitz): move this type to some leaf package - type updateResponse struct { - Err string // error message, if any - Enabled bool // user has opted-in to remote updates - Supported bool // Tailscale supports updating this OS/platform - Started bool - } - var res updateResponse - res.Enabled = envknob.AllowsRemoteUpdate() - res.Supported = runtime.GOOS == "windows" || (runtime.GOOS == "linux" && distro.Get() == distro.Debian) + res := b.newC2NUpdateResponse() + res.Started = b.c2nUpdateStarted() - switch r.Method { - case "GET", "POST": - default: - http.Error(w, "bad method", http.StatusMethodNotAllowed) - return - } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) +} +func (b *LocalBackend) handleC2NUpdatePost(w http.ResponseWriter, r *http.Request) { + b.logf("c2n: POST /update received") + res := b.newC2NUpdateResponse() defer func() { + if res.Err != "" { + b.logf("c2n: POST /update failed: %s", res.Err) + } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(res) }() - if r.Method == "GET" { - return - } - if !res.Enabled { res.Err = 
"not enabled" return } - if !res.Supported { res.Err = "not supported" return } + + // Check if update was already started, and mark as started. + if !b.trySetC2NUpdateStarted() { + res.Err = "update already started" + return + } + defer func() { + // Clear the started flag if something failed. + if res.Err != "" { + b.setC2NUpdateStarted(false) + } + }() + cmdTS, err := findCmdTailscale() if err != nil { res.Err = fmt.Sprintf("failed to find cmd/tailscale binary: %v", err) @@ -163,22 +184,64 @@ func (b *LocalBackend) handleC2NUpdate(w http.ResponseWriter, r *http.Request) { res.Err = "cmd/tailscale version mismatch" return } + cmd := exec.Command(cmdTS, "update", "--yes") + buf := new(bytes.Buffer) + cmd.Stdout = buf + cmd.Stderr = buf + b.logf("c2n: running %q", strings.Join(cmd.Args, " ")) if err := cmd.Start(); err != nil { res.Err = fmt.Sprintf("failed to start cmd/tailscale update: %v", err) return } res.Started = true - // TODO(bradfitz,andrew): There might be a race condition here on Windows: - // * We start the update process. - // * tailscale.exe copies itself and kicks off the update process - // * msiexec stops this process during the update before the selfCopy exits(?) - // * This doesn't return because the process is dead. + // Run update asynchronously and respond that it started. + go func() { + if err := cmd.Wait(); err != nil { + b.logf("c2n: update command failed: %v, output: %s", err, buf) + } else { + b.logf("c2n: update complete") + } + b.setC2NUpdateStarted(false) + }() +} + +func (b *LocalBackend) newC2NUpdateResponse() tailcfg.C2NUpdateResponse { + // If NewUpdater does not return an error, we can update the installation. + // Exception: When version.IsMacSysExt returns true, we don't support that + // yet. TODO(cpalmer, #6995): Implement it. // - // This seems fairly unlikely, but worth checking. - defer cmd.Wait() - return + // Note that we create the Updater solely to check for errors; we do not + // invoke it here. 
For this purpose, it is ok to pass it a zero Arguments. + prefs := b.Prefs().AutoUpdate() + _, err := clientupdate.NewUpdater(clientupdate.Arguments{}) + return tailcfg.C2NUpdateResponse{ + Enabled: envknob.AllowsRemoteUpdate() || prefs.Apply, + Supported: err == nil && !version.IsMacSysExt(), + } +} + +func (b *LocalBackend) c2nUpdateStarted() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.c2nUpdateStatus.started +} + +func (b *LocalBackend) setC2NUpdateStarted(v bool) { + b.mu.Lock() + defer b.mu.Unlock() + b.c2nUpdateStatus.started = v +} + +func (b *LocalBackend) trySetC2NUpdateStarted() bool { + b.mu.Lock() + defer b.mu.Unlock() + if b.c2nUpdateStatus.started { + return false + } + b.c2nUpdateStatus.started = true + return true } // findCmdTailscale looks for the cmd/tailscale that corresponds to the diff --git a/vendor/tailscale.com/ipn/ipnlocal/c2n_pprof.go b/vendor/tailscale.com/ipn/ipnlocal/c2n_pprof.go new file mode 100644 index 0000000000..9341548ee8 --- /dev/null +++ b/vendor/tailscale.com/ipn/ipnlocal/c2n_pprof.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !wasm + +package ipnlocal + +import ( + "net/http" + "runtime/pprof" +) + +func init() { + c2nLogHeap = func(w http.ResponseWriter, r *http.Request) { + pprof.WriteHeapProfile(w) + } +} diff --git a/vendor/tailscale.com/ipn/ipnlocal/cert.go b/vendor/tailscale.com/ipn/ipnlocal/cert.go index 33bbac8df1..de5dee3e96 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/cert.go +++ b/vendor/tailscale.com/ipn/ipnlocal/cert.go @@ -22,16 +22,17 @@ import ( "fmt" "io" "log" + insecurerand "math/rand" "net" "os" "path/filepath" "runtime" + "slices" "strings" "sync" "time" - "golang.org/x/crypto/acme" - "golang.org/x/exp/slices" + "github.com/tailscale/golang-x-crypto/acme" "tailscale.com/atomicfile" "tailscale.com/envknob" "tailscale.com/hostinfo" @@ -52,8 +53,8 @@ var ( // populate the on-disk cache and the rest should use that. 
acmeMu sync.Mutex - renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time - lastRenewCheck = map[string]time.Time{} + renewMu sync.Mutex // lock order: acmeMu before renewMu + renewCertAt = map[string]time.Time{} ) // certDir returns (creating if needed) the directory in which cached @@ -79,14 +80,20 @@ func (b *LocalBackend) certDir() (string, error) { var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME") -// getCertPEM gets the KeyPair for domain, either from cache, via the ACME -// process, or from cache and kicking off an async ACME renewal. -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { +// GetCertPEM gets the TLSCertKeyPair for domain, either from cache or via the +// ACME process. ACME process is used for new domain certs, existing expired +// certs or existing certs that should get renewed due to upcoming expiry. +// +// syncRenewal changes renewal behavior for existing certs that are still valid +// but need renewal. When syncRenewal is set, the method blocks until a new +// cert is issued. When syncRenewal is not set, existing cert is returned right +// away and renewal is kicked off in a background goroutine. 
+func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } logf := logger.WithPrefix(b.logf, fmt.Sprintf("cert(%q): ", domain)) - now := time.Now() + now := b.clock.Now() traceACME := func(v any) { if !acmeDebug() { return @@ -101,15 +108,18 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK } if pair, err := getCertPEMCached(cs, domain, now); err == nil { - shouldRenew, err := shouldStartDomainRenewal(domain, now, pair) + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair) if err != nil { logf("error checking for certificate renewal: %v", err) - } else if shouldRenew { + } else if !shouldRenew { + return pair, nil + } + if !syncRenewal { logf("starting async renewal") // Start renewal in the background. go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) } - return pair, nil + // Synchronous renewal happens below. } pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) @@ -120,28 +130,46 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK return pair, nil } -func shouldStartDomainRenewal(domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { +func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { renewMu.Lock() defer renewMu.Unlock() - if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute { - // We checked very recently. Don't bother reparsing & - // validating the x509 cert. - return false, nil + if renewAt, ok := renewCertAt[domain]; ok { + return now.After(renewAt), nil + } + + renewTime, err := b.domainRenewalTimeByARI(cs, pair) + if err != nil { + // Log any ARI failure and fall back to checking for renewal by expiry. 
+ b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err) + renewTime, err = b.domainRenewalTimeByExpiry(pair) + if err != nil { + return false, err + } } - lastRenewCheck[domain] = now + renewCertAt[domain] = renewTime + return now.After(renewTime), nil +} + +func (b *LocalBackend) domainRenewed(domain string) { + renewMu.Lock() + defer renewMu.Unlock() + delete(renewCertAt, domain) +} + +func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) { block, _ := pem.Decode(pair.CertPEM) if block == nil { - return false, fmt.Errorf("parsing certificate PEM") + return time.Time{}, fmt.Errorf("parsing certificate PEM") } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { - return false, fmt.Errorf("parsing certificate: %w", err) + return time.Time{}, fmt.Errorf("parsing certificate: %w", err) } certLifetime := cert.NotAfter.Sub(cert.NotBefore) if certLifetime < 0 { - return false, fmt.Errorf("negative certificate lifetime %v", certLifetime) + return time.Time{}, fmt.Errorf("negative certificate lifetime %v", certLifetime) } // Per https://github.com/tailscale/tailscale/issues/8204, check @@ -150,11 +178,43 @@ func shouldStartDomainRenewal(domain string, now time.Time, pair *TLSCertKeyPair // Encrypt. 
renewalDuration := certLifetime * 2 / 3 renewAt := cert.NotBefore.Add(renewalDuration) + return renewAt, nil +} - if now.After(renewAt) { - return true, nil +func (b *LocalBackend) domainRenewalTimeByARI(cs certStore, pair *TLSCertKeyPair) (time.Time, error) { + var blocks []*pem.Block + rest := pair.CertPEM + for len(rest) > 0 { + var block *pem.Block + block, rest = pem.Decode(rest) + if block == nil { + return time.Time{}, fmt.Errorf("parsing certificate PEM") + } + blocks = append(blocks, block) + } + if len(blocks) < 2 { + return time.Time{}, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks)) + } + ac, err := acmeClient(cs) + if err != nil { + return time.Time{}, err + } + ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second) + defer cancel() + ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes) + if err != nil { + return time.Time{}, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err) } - return false, nil + if acmeDebug() { + b.logf("acme: ARI response: %+v", ri) + } + + // Select a random time in the suggested window and renew if that time has + // passed. Time is randomized per recommendation in + // https://datatracker.ietf.org/doc/draft-ietf-acme-ari/ + start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End + renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start))))) + return renewTime, nil } // certStore provides a way to perist and retrieve TLS certificates. 
@@ -279,11 +339,11 @@ func (s certStateStore) Read(domain string, now time.Time) (*TLSCertKeyPair, err } func (s certStateStore) WriteCert(domain string, cert []byte) error { - return s.WriteState(ipn.StateKey(domain+".crt"), cert) + return ipn.WriteState(s.StateStore, ipn.StateKey(domain+".crt"), cert) } func (s certStateStore) WriteKey(domain string, key []byte) error { - return s.WriteState(ipn.StateKey(domain+".key"), key) + return ipn.WriteState(s.StateStore, ipn.StateKey(domain+".key"), key) } func (s certStateStore) ACMEKey() ([]byte, error) { @@ -291,7 +351,7 @@ func (s certStateStore) ACMEKey() ([]byte, error) { } func (s certStateStore) WriteACMEKey(key []byte) error { - return s.WriteState(ipn.StateKey(acmePEMName), key) + return ipn.WriteState(s.StateStore, ipn.StateKey(acmePEMName), key) } // TLSCertKeyPair is a TLS public and private key, and whether they were obtained @@ -322,19 +382,25 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger acmeMu.Lock() defer acmeMu.Unlock() + // In case this method was triggered multiple times in parallel (when + // serving incoming requests), check whether one of the other goroutines + // already renewed the cert before us. if p, err := getCertPEMCached(cs, domain, now); err == nil { - return p, nil + // shouldStartDomainRenewal caches its result so it's OK to call this + // frequently. 
+ shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p) + if err != nil { + logf("error checking for certificate renewal: %v", err) + } else if !shouldRenew { + return p, nil + } } else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) { return nil, err } - key, err := acmeKey(cs) + ac, err := acmeClient(cs) if err != nil { - return nil, fmt.Errorf("acmeKey: %w", err) - } - ac := &acme.Client{ - Key: key, - UserAgent: "tailscaled/" + version.Long(), + return nil, err } a, err := ac.GetReg(ctx, "" /* pre-RFC param */) @@ -464,6 +530,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil { return nil, err } + b.domainRenewed(domain) return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil } @@ -540,6 +607,20 @@ func acmeKey(cs certStore) (crypto.Signer, error) { return privKey, nil } +func acmeClient(cs certStore) (*acme.Client, error) { + key, err := acmeKey(cs) + if err != nil { + return nil, fmt.Errorf("acmeKey: %w", err) + } + // Note: if we add support for additional ACME providers (other than + // LetsEncrypt), we should make sure that they support ARI extension (see + // shouldStartDomainRenewalARI). + return &acme.Client{ + Key: key, + UserAgent: "tailscaled/" + version.Long(), + }, nil +} + // validCertPEM reports whether the given certificate is valid for domain at now. // // If roots != nil, it is used instead of the system root pool. 
This is meant diff --git a/vendor/tailscale.com/ipn/ipnlocal/cert_js.go b/vendor/tailscale.com/ipn/ipnlocal/cert_js.go index a5fdfc4ba2..24defb47bf 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/cert_js.go +++ b/vendor/tailscale.com/ipn/ipnlocal/cert_js.go @@ -12,6 +12,6 @@ type TLSCertKeyPair struct { CertPEM, KeyPEM []byte } -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { return nil, errors.New("not implemented for js/wasm") } diff --git a/vendor/tailscale.com/ipn/ipnlocal/expiry.go b/vendor/tailscale.com/ipn/ipnlocal/expiry.go index 0df30bf567..13e57d3275 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/expiry.go +++ b/vendor/tailscale.com/ipn/ipnlocal/expiry.go @@ -8,6 +8,7 @@ import ( "tailscale.com/syncs" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" @@ -37,22 +38,22 @@ type expiryManager struct { // time.Now().Add(clockDelta) == MapResponse.ControlTime clockDelta syncs.AtomicValue[time.Duration] - logf logger.Logf - timeNow func() time.Time + logf logger.Logf + clock tstime.Clock } func newExpiryManager(logf logger.Logf) *expiryManager { return &expiryManager{ previouslyExpired: map[tailcfg.StableNodeID]bool{}, logf: logf, - timeNow: time.Now, + clock: tstime.StdClock{}, } } // onControlTime is called whenever we receive a new timestamp from the control // server to store the delta. 
func (em *expiryManager) onControlTime(t time.Time) { - localNow := em.timeNow() + localNow := em.clock.Now() delta := t.Sub(localNow) if delta.Abs() > minClockDelta { em.logf("[v1] netmap: flagExpiredPeers: setting clock delta to %v", delta) @@ -86,24 +87,26 @@ func (em *expiryManager) flagExpiredPeers(netmap *netmap.NetworkMap, localNow ti return } - for _, peer := range netmap.Peers { + for i, peer := range netmap.Peers { // Nodes that don't expire have KeyExpiry set to the zero time; // skip those and peers that are already marked as expired // (e.g. from control). - if peer.KeyExpiry.IsZero() || peer.KeyExpiry.After(controlNow) { - delete(em.previouslyExpired, peer.StableID) + if peer.KeyExpiry().IsZero() || peer.KeyExpiry().After(controlNow) { + delete(em.previouslyExpired, peer.StableID()) continue - } else if peer.Expired { + } else if peer.Expired() { continue } - if !em.previouslyExpired[peer.StableID] { - em.logf("[v1] netmap: flagExpiredPeers: clearing expired peer %v", peer.StableID) - em.previouslyExpired[peer.StableID] = true + if !em.previouslyExpired[peer.StableID()] { + em.logf("[v1] netmap: flagExpiredPeers: clearing expired peer %v", peer.StableID()) + em.previouslyExpired[peer.StableID()] = true } + mut := peer.AsStruct() + // Actually mark the node as expired - peer.Expired = true + mut.Expired = true // Control clears the Endpoints and DERP fields of expired // nodes; do so here as well. The Expired bool is the correct @@ -112,12 +115,14 @@ func (em *expiryManager) flagExpiredPeers(netmap *netmap.NetworkMap, localNow ti // NOTE: this is insufficient to actually break connectivity, // since we discover endpoints via DERP, and due to DERP return // path optimization. - peer.Endpoints = nil - peer.DERP = "" + mut.Endpoints = nil + mut.DERP = "" // Defense-in-depth: break the node's public key as well, in // case something tries to communicate. 
- peer.Key = key.NodePublicWithBadOldPrefix(peer.Key) + mut.Key = key.NodePublicWithBadOldPrefix(peer.Key()) + + netmap.Peers[i] = mut.View() } } @@ -143,13 +148,13 @@ func (em *expiryManager) nextPeerExpiry(nm *netmap.NetworkMap, localNow time.Tim var nextExpiry time.Time // zero if none for _, peer := range nm.Peers { - if peer.KeyExpiry.IsZero() { + if peer.KeyExpiry().IsZero() { continue // tagged node - } else if peer.Expired { + } else if peer.Expired() { // Peer already expired; Expired is set by the // flagExpiredPeers function, above. continue - } else if peer.KeyExpiry.Before(controlNow) { + } else if peer.KeyExpiry().Before(controlNow) { // This peer already expired, and peer.Expired // isn't set for some reason. Skip this node. continue @@ -159,14 +164,14 @@ func (em *expiryManager) nextPeerExpiry(nm *netmap.NetworkMap, localNow time.Tim // an expiry; otherwise, only update if this node's expiry is // sooner than the currently-stored one (since we want the // soonest-occurring expiry time). - if nextExpiry.IsZero() || peer.KeyExpiry.Before(nextExpiry) { - nextExpiry = peer.KeyExpiry + if nextExpiry.IsZero() || peer.KeyExpiry().Before(nextExpiry) { + nextExpiry = peer.KeyExpiry() } } // Ensure that we also fire this timer if our own node key expires. 
- if nm.SelfNode != nil { - selfExpiry := nm.SelfNode.KeyExpiry + if nm.SelfNode.Valid() { + selfExpiry := nm.SelfNode.KeyExpiry() if selfExpiry.IsZero() { // No expiry for self node diff --git a/vendor/tailscale.com/ipn/ipnlocal/local.go b/vendor/tailscale.com/ipn/ipnlocal/local.go index 762dfba9b2..97440bc81b 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/local.go +++ b/vendor/tailscale.com/ipn/ipnlocal/local.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log" + "maps" "net" "net/http" "net/http/httputil" @@ -20,6 +21,7 @@ import ( "os/user" "path/filepath" "runtime" + "slices" "sort" "strconv" "strings" @@ -29,10 +31,11 @@ import ( "go4.org/mem" "go4.org/netipx" - "golang.org/x/exp/slices" + xmaps "golang.org/x/exp/maps" "gvisor.dev/gvisor/pkg/tcpip" "tailscale.com/client/tailscale/apitype" "tailscale.com/control/controlclient" + "tailscale.com/control/controlknobs" "tailscale.com/doctor" "tailscale.com/doctor/permissions" "tailscale.com/doctor/routetable" @@ -50,6 +53,7 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" "tailscale.com/net/interfaces" + "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/netutil" "tailscale.com/net/tsaddr" @@ -60,6 +64,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/tsd" + "tailscale.com/tstime" "tailscale.com/types/dnstype" "tailscale.com/types/empty" "tailscale.com/types/key" @@ -76,8 +81,10 @@ import ( "tailscale.com/util/mak" "tailscale.com/util/multierr" "tailscale.com/util/osshare" + "tailscale.com/util/rands" "tailscale.com/util/set" "tailscale.com/util/systemd" + "tailscale.com/util/testenv" "tailscale.com/util/uniq" "tailscale.com/version" "tailscale.com/version/distro" @@ -140,19 +147,17 @@ type LocalBackend struct { statsLogf logger.Logf // for printing peers stats on change sys *tsd.System e wgengine.Engine // non-nil; TODO(bradfitz): remove; use sys - pm *profileManager - store ipn.StateStore // non-nil; TODO(bradfitz): remove; use sys - dialer 
*tsdial.Dialer // non-nil; TODO(bradfitz): remove; use sys + store ipn.StateStore // non-nil; TODO(bradfitz): remove; use sys + dialer *tsdial.Dialer // non-nil; TODO(bradfitz): remove; use sys backendLogID logid.PublicID unregisterNetMon func() unregisterHealthWatch func() portpoll *portlist.Poller // may be nil portpollOnce sync.Once // guards starting readPoller gotPortPollRes chan struct{} // closed upon first readPoller result - newDecompressor func() (controlclient.Decompressor, error) - varRoot string // or empty if SetVarRoot never called - logFlushFunc func() // or nil if SetLogFlusher wasn't called - em *expiryManager // non-nil + varRoot string // or empty if SetVarRoot never called + logFlushFunc func() // or nil if SetLogFlusher wasn't called + em *expiryManager // non-nil sshAtomicBool atomic.Bool shutdownCalled bool // if Shutdown has been called debugSink *capture.Sink @@ -185,6 +190,7 @@ type LocalBackend struct { // The mutex protects the following elements. mu sync.Mutex + pm *profileManager // mu guards access filterHash deephash.Sum httpTestClient *http.Client // for controlclient. nil by default, used by tests. ccGen clientGen // function for producing controlclient; lazily populated @@ -199,11 +205,20 @@ type LocalBackend struct { capTailnetLock bool // whether netMap contains the tailnet lock capability // hostinfo is mutated in-place while mu is held. hostinfo *tailcfg.Hostinfo - // netMap is not mutated in-place once set. - netMap *netmap.NetworkMap - nmExpiryTimer *time.Timer // for updating netMap on node expiry; can be nil - nodeByAddr map[netip.Addr]*tailcfg.Node - activeLogin string // last logged LoginName from netMap + // netMap is the most recently set full netmap from the controlclient. + // It can't be mutated in place once set. Because it can't be mutated in place, + // delta updates from the control server don't apply to it. Instead, use + // the peers map to get up-to-date information on the state of peers. 
+ // In general, avoid using the netMap.Peers slice. We'd like it to go away + // as of 2023-09-17. + netMap *netmap.NetworkMap + // peers is the set of current peers and their current values after applying + // delta node mutations as they come in (with mu held). The map values can + // be given out to callers, but the map itself must not escape the LocalBackend. + peers map[tailcfg.NodeID]tailcfg.NodeView + nodeByAddr map[netip.Addr]tailcfg.NodeID + nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil + activeLogin string // last logged LoginName from netMap engineStatus ipn.EngineStatus endpoints []tailcfg.Endpoint blocked bool @@ -235,10 +250,13 @@ type LocalBackend struct { directFileRoot string directFileDoFinalRename bool // false on macOS, true on several NAS platforms componentLogUntil map[string]componentLogState + // c2nUpdateStatus is the status of c2n-triggered client update. + c2nUpdateStatus updateStatus // ServeConfig fields. (also guarded by mu) - lastServeConfJSON mem.RO // last JSON that was parsed into serveConfig - serveConfig ipn.ServeConfigView // or !Valid if none + lastServeConfJSON mem.RO // last JSON that was parsed into serveConfig + serveConfig ipn.ServeConfigView // or !Valid if none + activeWatchSessions set.Set[string] // of WatchIPN SessionID serveListeners map[netip.AddrPort]*serveListener // addrPort => serveListener serveProxyHandlers sync.Map // string (HTTPHandler.Proxy) => *httputil.ReverseProxy @@ -259,6 +277,14 @@ type LocalBackend struct { // tkaSyncLock MUST be taken before mu (or inversely, mu must not be held // at the moment that tkaSyncLock is taken). tkaSyncLock sync.Mutex + clock tstime.Clock + + // Last ClientVersion received in MapResponse, guarded by mu. + lastClientVersion *tailcfg.ClientVersion +} + +type updateStatus struct { + started bool } // clientGen is a func that creates a control plane client. 
@@ -273,6 +299,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo e := sys.Engine.Get() store := sys.StateStore.Get() dialer := sys.Dialer.Get() + _ = sys.MagicSock.Get() // or panic pm, err := newProfileManager(store, logf) if err != nil { @@ -293,24 +320,27 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo ctx, cancel := context.WithCancel(context.Background()) portpoll := new(portlist.Poller) + clock := tstime.StdClock{} b := &LocalBackend{ - ctx: ctx, - ctxCancel: cancel, - logf: logf, - keyLogf: logger.LogOnChange(logf, 5*time.Minute, time.Now), - statsLogf: logger.LogOnChange(logf, 5*time.Minute, time.Now), - sys: sys, - e: e, - dialer: dialer, - store: store, - pm: pm, - backendLogID: logID, - state: ipn.NoState, - portpoll: portpoll, - em: newExpiryManager(logf), - gotPortPollRes: make(chan struct{}), - loginFlags: loginFlags, + ctx: ctx, + ctxCancel: cancel, + logf: logf, + keyLogf: logger.LogOnChange(logf, 5*time.Minute, clock.Now), + statsLogf: logger.LogOnChange(logf, 5*time.Minute, clock.Now), + sys: sys, + e: e, + dialer: dialer, + store: store, + pm: pm, + backendLogID: logID, + state: ipn.NoState, + portpoll: portpoll, + em: newExpiryManager(logf), + gotPortPollRes: make(chan struct{}), + loginFlags: loginFlags, + clock: clock, + activeWatchSessions: make(set.Set[string]), } netMon := sys.NetMon.Get() @@ -334,7 +364,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo b.prevIfState = netMon.InterfaceState() // Call our linkChange code once with the current state, and // then also whenever it changes: - b.linkChange(false, netMon.InterfaceState()) + b.linkChange(&netmon.ChangeDelta{New: netMon.InterfaceState()}) b.unregisterNetMon = netMon.RegisterChangeCallback(b.linkChange) b.unregisterHealthWatch = health.RegisterWatcher(b.onHealthChange) @@ -348,7 +378,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo for _, 
component := range debuggableComponents { key := componentStateKey(component) if ut, err := ipn.ReadStoreInt(pm.Store(), key); err == nil { - if until := time.Unix(ut, 0); until.After(time.Now()) { + if until := time.Unix(ut, 0); until.After(b.clock.Now()) { // conditional to avoid log spam at start when off b.SetComponentDebugLogging(component, until) } @@ -360,7 +390,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo type componentLogState struct { until time.Time - timer *time.Timer // if non-nil, the AfterFunc to disable it + timer tstime.TimerController // if non-nil, the AfterFunc to disable it } var debuggableComponents = []string{ @@ -386,11 +416,7 @@ func (b *LocalBackend) SetComponentDebugLogging(component string, until time.Tim var setEnabled func(bool) switch component { case "magicsock": - mc, err := b.magicConn() - if err != nil { - return err - } - setEnabled = mc.SetDebugLoggingEnabled + setEnabled = b.magicConn().SetDebugLoggingEnabled case "sockstats": if b.sockstatLogger != nil { setEnabled = func(v bool) { @@ -413,7 +439,7 @@ func (b *LocalBackend) SetComponentDebugLogging(component string, until time.Tim return t.Unix() } ipn.PutStoreInt(b.store, componentStateKey(component), timeUnixOrZero(until)) - now := time.Now() + now := b.clock.Now() on := now.Before(until) setEnabled(on) var onFor time.Duration @@ -428,7 +454,7 @@ func (b *LocalBackend) SetComponentDebugLogging(component string, until time.Tim } newSt := componentLogState{until: until} if on { - newSt.timer = time.AfterFunc(onFor, func() { + newSt.timer = b.clock.AfterFunc(onFor, func() { // Turn off logging after the timer fires, as long as the state is // unchanged when the timer actually fires. 
b.mu.Lock() @@ -450,7 +476,7 @@ func (b *LocalBackend) GetComponentDebugLogging(component string) time.Time { b.mu.Lock() defer b.mu.Unlock() - now := time.Now() + now := b.clock.Now() ls := b.componentLogUntil[component] if ls.until.IsZero() || ls.until.Before(now) { return time.Time{} @@ -486,24 +512,28 @@ func (b *LocalBackend) SetDirectFileDoFinalRename(v bool) { b.directFileDoFinalRename = v } +// pauseOrResumeControlClientLocked pauses b.cc if there is no network available +// or if the LocalBackend is in Stopped state with a valid NetMap. In all other +// cases, it unpauses it. It is a no-op if b.cc is nil. +// // b.mu must be held. -func (b *LocalBackend) maybePauseControlClientLocked() { +func (b *LocalBackend) pauseOrResumeControlClientLocked() { if b.cc == nil { return } networkUp := b.prevIfState.AnyInterfaceUp() - b.cc.SetPaused((b.state == ipn.Stopped && b.netMap != nil) || !networkUp) + b.cc.SetPaused((b.state == ipn.Stopped && b.netMap != nil) || (!networkUp && !testenv.InTest())) } // linkChange is our network monitor callback, called whenever the network changes. -// major is whether ifst is different than earlier. -func (b *LocalBackend) linkChange(major bool, ifst *interfaces.State) { +func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) { b.mu.Lock() defer b.mu.Unlock() + ifst := delta.New hadPAC := b.prevIfState.HasPAC() b.prevIfState = ifst - b.maybePauseControlClientLocked() + b.pauseOrResumeControlClientLocked() // If the PAC-ness of the network changed, reconfig wireguard+route to // add/remove subnets. 
@@ -522,7 +552,7 @@ func (b *LocalBackend) linkChange(major bool, ifst *interfaces.State) { b.updateFilterLocked(b.netMap, b.pm.CurrentPrefs()) if peerAPIListenAsync && b.netMap != nil && b.state == ipn.Running { - want := len(b.netMap.Addresses) + want := b.netMap.GetAddresses().Len() if len(b.peerAPIListeners) < want { b.logf("linkChange: peerAPIListeners too low; trying again") go b.initPeerAPIListener() @@ -552,7 +582,14 @@ func (b *LocalBackend) Shutdown() { b.mu.Unlock() ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second) defer cancel() - b.LogoutSync(ctx) // best effort + t0 := time.Now() + err := b.Logout(ctx) // best effort + td := time.Since(t0).Round(time.Millisecond) + if err != nil { + b.logf("failed to log out ephemeral node on shutdown after %v: %v", td, err) + } else { + b.logf("logged out ephemeral node on shutdown") + } b.mu.Lock() } cc := b.cc @@ -624,25 +661,19 @@ func (b *LocalBackend) StatusWithoutPeers() *ipnstate.Status { // UpdateStatus implements ipnstate.StatusUpdater. func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { - b.e.UpdateStatus(sb) - var extraLocked func(*ipnstate.StatusBuilder) - if sb.WantPeers { - extraLocked = b.populatePeerStatusLocked - } - b.updateStatus(sb, extraLocked) -} + b.e.UpdateStatus(sb) // does wireguard + magicsock status -// updateStatus populates sb with status. -// -// extraLocked, if non-nil, is called while b.mu is still held. 
-func (b *LocalBackend) updateStatus(sb *ipnstate.StatusBuilder, extraLocked func(*ipnstate.StatusBuilder)) { b.mu.Lock() defer b.mu.Unlock() + sb.MutateStatus(func(s *ipnstate.Status) { s.Version = version.Long() s.TUN = !b.sys.IsNetstack() s.BackendState = b.state.String() s.AuthURL = b.authURLSticky + if prefs := b.pm.CurrentPrefs(); prefs.Valid() && prefs.AutoUpdate().Check { + s.ClientVersion = b.lastClientVersion + } if err := health.OverallError(); err != nil { switch e := err.(type) { case multierr.Error: @@ -675,32 +706,58 @@ func (b *LocalBackend) updateStatus(sb *ipnstate.StatusBuilder, extraLocked func if !prefs.ExitNodeID().IsZero() { if exitPeer, ok := b.netMap.PeerWithStableID(prefs.ExitNodeID()); ok { var online = false - if exitPeer.Online != nil { - online = *exitPeer.Online + if v := exitPeer.Online(); v != nil { + online = *v } s.ExitNodeStatus = &ipnstate.ExitNodeStatus{ ID: prefs.ExitNodeID(), Online: online, - TailscaleIPs: exitPeer.Addresses, + TailscaleIPs: exitPeer.Addresses().AsSlice(), } } } } } }) + + var tailscaleIPs []netip.Addr + if b.netMap != nil { + addrs := b.netMap.GetAddresses() + for i := range addrs.LenIter() { + if addr := addrs.At(i); addr.IsSingleIP() { + sb.AddTailscaleIP(addr.Addr()) + tailscaleIPs = append(tailscaleIPs, addr.Addr()) + } + } + } + sb.MutateSelfStatus(func(ss *ipnstate.PeerStatus) { + ss.OS = version.OS() ss.Online = health.GetInPollNetMap() if b.netMap != nil { ss.InNetworkMap = true - ss.HostName = b.netMap.Hostinfo.Hostname + if hi := b.netMap.SelfNode.Hostinfo(); hi.Valid() { + ss.HostName = hi.Hostname() + } ss.DNSName = b.netMap.Name - ss.UserID = b.netMap.User - if sn := b.netMap.SelfNode; sn != nil { + ss.UserID = b.netMap.User() + if sn := b.netMap.SelfNode; sn.Valid() { peerStatusFromNode(ss, sn) - if c := sn.Capabilities; len(c) > 0 { - ss.Capabilities = append([]string(nil), c...) 
+ if c := sn.Capabilities(); c.Len() > 0 { + ss.Capabilities = c.AsSlice() } + if cm := sn.CapMap(); cm.Len() > 0 { + ss.CapMap = make(tailcfg.NodeCapMap, sn.CapMap().Len()) + cm.Range(func(k tailcfg.NodeCapability, v views.Slice[tailcfg.RawMessage]) bool { + ss.CapMap[k] = v.AsSlice() + return true + }) + } + } + for _, addr := range tailscaleIPs { + ss.TailscaleIPs = append(ss.TailscaleIPs, addr) } + } else { ss.HostName, _ = os.Hostname() } @@ -711,8 +768,8 @@ func (b *LocalBackend) updateStatus(sb *ipnstate.StatusBuilder, extraLocked func // TODO: hostinfo, and its networkinfo // TODO: EngineStatus copy (and deprecate it?) - if extraLocked != nil { - extraLocked(sb) + if sb.WantPeers { + b.populatePeerStatusLocked(sb) } } @@ -724,30 +781,33 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { sb.AddUser(id, up) } exitNodeID := b.pm.CurrentPrefs().ExitNodeID() - for _, p := range b.netMap.Peers { + for _, p := range b.peers { var lastSeen time.Time - if p.LastSeen != nil { - lastSeen = *p.LastSeen + if p.LastSeen() != nil { + lastSeen = *p.LastSeen() } - var tailscaleIPs = make([]netip.Addr, 0, len(p.Addresses)) - for _, addr := range p.Addresses { + var tailscaleIPs = make([]netip.Addr, 0, p.Addresses().Len()) + for i := range p.Addresses().LenIter() { + addr := p.Addresses().At(i) if addr.IsSingleIP() && tsaddr.IsTailscaleIP(addr.Addr()) { tailscaleIPs = append(tailscaleIPs, addr.Addr()) } } + online := p.Online() ps := &ipnstate.PeerStatus{ - InNetworkMap: true, - UserID: p.User, - TailscaleIPs: tailscaleIPs, - HostName: p.Hostinfo.Hostname(), - DNSName: p.Name, - OS: p.Hostinfo.OS(), - KeepAlive: p.KeepAlive, - LastSeen: lastSeen, - Online: p.Online != nil && *p.Online, - ShareeNode: p.Hostinfo.ShareeNode(), - ExitNode: p.StableID != "" && p.StableID == exitNodeID, - SSH_HostKeys: p.Hostinfo.SSH_HostKeys().AsSlice(), + InNetworkMap: true, + UserID: p.User(), + AltSharerUserID: p.Sharer(), + TailscaleIPs: tailscaleIPs, + HostName: 
p.Hostinfo().Hostname(), + DNSName: p.Name(), + OS: p.Hostinfo().OS(), + LastSeen: lastSeen, + Online: online != nil && *online, + ShareeNode: p.Hostinfo().ShareeNode(), + ExitNode: p.StableID() != "" && p.StableID() == exitNodeID, + SSH_HostKeys: p.Hostinfo().SSH_HostKeys().AsSlice(), + Location: p.Hostinfo().Location(), } peerStatusFromNode(ps, p) @@ -758,29 +818,30 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { if u := peerAPIURL(nodeIP(p, netip.Addr.Is6), p6); u != "" { ps.PeerAPIURL = append(ps.PeerAPIURL, u) } - sb.AddPeer(p.Key, ps) + sb.AddPeer(p.Key(), ps) } } // peerStatusFromNode copies fields that exist in the Node struct for // current node and peers into the provided PeerStatus. -func peerStatusFromNode(ps *ipnstate.PeerStatus, n *tailcfg.Node) { - ps.ID = n.StableID - ps.Created = n.Created - ps.ExitNodeOption = tsaddr.ContainsExitRoutes(n.AllowedIPs) - if n.Tags != nil { - v := views.SliceOf(n.Tags) +func peerStatusFromNode(ps *ipnstate.PeerStatus, n tailcfg.NodeView) { + ps.PublicKey = n.Key() + ps.ID = n.StableID() + ps.Created = n.Created() + ps.ExitNodeOption = tsaddr.ContainsExitRoutes(n.AllowedIPs()) + if n.Tags().Len() != 0 { + v := n.Tags() ps.Tags = &v } - if n.PrimaryRoutes != nil { - v := views.IPPrefixSliceOf(n.PrimaryRoutes) + if n.PrimaryRoutes().Len() != 0 { + v := n.PrimaryRoutes() ps.PrimaryRoutes = &v } - if n.Expired { + if n.Expired() { ps.Expired = true } - if t := n.KeyExpiry; !t.IsZero() { + if t := n.KeyExpiry(); !t.IsZero() { t = t.Round(time.Second) ps.KeyExpiry = &t } @@ -789,39 +850,51 @@ func peerStatusFromNode(ps *ipnstate.PeerStatus, n *tailcfg.Node) { // WhoIs reports the node and user who owns the node with the given IP:port. // If the IP address is a Tailscale IP, the provided port may be 0. // If ok == true, n and u are valid. 
-func (b *LocalBackend) WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool) { +func (b *LocalBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { + var zero tailcfg.NodeView b.mu.Lock() defer b.mu.Unlock() - n, ok = b.nodeByAddr[ipp.Addr()] + nid, ok := b.nodeByAddr[ipp.Addr()] if !ok { var ip netip.Addr if ipp.Port() != 0 { - ip, ok = b.e.WhoIsIPPort(ipp) + ip, ok = b.sys.ProxyMapper().WhoIsIPPort(ipp) } if !ok { - return nil, u, false + return zero, u, false } - n, ok = b.nodeByAddr[ip] + nid, ok = b.nodeByAddr[ip] if !ok { - return nil, u, false + return zero, u, false + } + } + if b.netMap == nil { + return zero, u, false + } + n, ok = b.peers[nid] + if !ok { + // Check if this the self-node, which would not appear in peers. + if !b.netMap.SelfNode.Valid() || nid != b.netMap.SelfNode.ID() { + return zero, u, false } + n = b.netMap.SelfNode } - u, ok = b.netMap.UserProfiles[n.User] + u, ok = b.netMap.UserProfiles[n.User()] if !ok { - return nil, u, false + return zero, u, false } return n, u, true } // PeerCaps returns the capabilities that remote src IP has to // ths current node. 
-func (b *LocalBackend) PeerCaps(src netip.Addr) []string { +func (b *LocalBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { b.mu.Lock() defer b.mu.Unlock() return b.peerCapsLocked(src) } -func (b *LocalBackend) peerCapsLocked(src netip.Addr) []string { +func (b *LocalBackend) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { if b.netMap == nil { return nil } @@ -829,34 +902,32 @@ func (b *LocalBackend) peerCapsLocked(src netip.Addr) []string { if filt == nil { return nil } - for _, a := range b.netMap.Addresses { + addrs := b.netMap.GetAddresses() + for i := range addrs.LenIter() { + a := addrs.At(i) if !a.IsSingleIP() { continue } dst := a.Addr() if dst.BitLen() == src.BitLen() { // match on family - return filt.AppendCaps(nil, src, dst) + return filt.CapsWithValues(src, dst) } } return nil } -// SetDecompressor sets a decompression function, which must be a zstd -// reader. -// -// This exists because the iOS/Mac NetworkExtension is very resource -// constrained, and the zstd package is too heavy to fit in the -// constrained RSS limit. -func (b *LocalBackend) SetDecompressor(fn func() (controlclient.Decompressor, error)) { - b.newDecompressor = fn -} - -// setClientStatus is the callback invoked by the control client whenever it posts a new status. +// SetControlClientStatus is the callback invoked by the control client whenever it posts a new status. // Among other things, this is where we update the netmap, packet filters, DNS and DERP maps. -func (b *LocalBackend) setClientStatus(st controlclient.Status) { +func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st controlclient.Status) { + b.mu.Lock() + if b.cc != c { + b.logf("Ignoring SetControlClientStatus from old client") + b.mu.Unlock() + return + } // The following do not depend on any data for which we need to lock b. if st.Err != nil { - // TODO(crawshaw): display in the UI. 
+ b.mu.Unlock() if errors.Is(st.Err, io.EOF) { b.logf("[v1] Received error: EOF") return @@ -873,11 +944,9 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { // Track the number of calls currCall := b.numClientStatusCalls.Add(1) - b.mu.Lock() - // Handle node expiry in the netmap if st.NetMap != nil { - now := time.Now() + now := b.clock.Now() b.em.flagExpiredPeers(st.NetMap, now) // Always stop the existing netmap timer if we have a netmap; @@ -897,7 +966,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { nextExpiry := b.em.nextPeerExpiry(st.NetMap, now) if !nextExpiry.IsZero() { tmrDuration := nextExpiry.Sub(now) + 10*time.Second - b.nmExpiryTimer = time.AfterFunc(tmrDuration, func() { + b.nmExpiryTimer = b.clock.AfterFunc(tmrDuration, func() { // Skip if the world has moved on past the // saved call (e.g. if we race stopping this // timer). @@ -910,7 +979,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { // Call ourselves with the current status again; the logic in // setClientStatus will take care of updating the expired field // of peers in the netmap. 
- b.setClientStatus(st) + b.SetControlClientStatus(c, st) }) } } @@ -919,7 +988,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { keyExpiryExtended := false if st.NetMap != nil { wasExpired := b.keyExpired - isExpired := !st.NetMap.Expiry.IsZero() && st.NetMap.Expiry.Before(time.Now()) + isExpired := !st.NetMap.Expiry.IsZero() && st.NetMap.Expiry.Before(b.clock.Now()) if wasExpired && !isExpired { keyExpiryExtended = true } @@ -932,7 +1001,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { b.blockEngineUpdates(false) } - if st.LoginFinished != nil && wasBlocked { + if st.LoginFinished() && wasBlocked { // Auth completed, unblock the engine b.blockEngineUpdates(false) b.authReconfig() @@ -942,20 +1011,6 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { // Lock b once and do only the things that require locking. b.mu.Lock() - if st.LogoutFinished != nil { - if p := b.pm.CurrentPrefs(); !p.Persist().Valid() || p.Persist().LoginName() == "" { - b.mu.Unlock() - return - } - if err := b.pm.DeleteProfile(b.pm.CurrentProfile().ID); err != nil { - b.logf("error deleting profile: %v", err) - } - if err := b.resetForProfileChangeLockedOnEntry(); err != nil { - b.logf("resetForProfileChangeLockedOnEntry err: %v", err) - } - return - } - prefsChanged := false prefs := b.pm.CurrentPrefs().AsStruct() netMap := b.netMap @@ -970,8 +1025,8 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { prefs.ControlURL = prefs.ControlURLOrDefault() prefsChanged = true } - if st.Persist != nil && st.Persist.Valid() { - if !prefs.Persist.View().Equals(*st.Persist) { + if st.Persist.Valid() { + if !prefs.Persist.View().Equals(st.Persist) { prefsChanged = true prefs.Persist = st.Persist.AsStruct() } @@ -980,7 +1035,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { b.authURL = st.URL b.authURLSticky = st.URL } - if wasBlocked && st.LoginFinished != nil { + if wasBlocked && st.LoginFinished() { // 
Interactive login finished successfully (URL visited). // After an interactive login, the user always wants // WantRunning. @@ -990,19 +1045,14 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { prefs.WantRunning = true prefs.LoggedOut = false } - if findExitNodeIDLocked(prefs, st.NetMap) { + if setExitNodeID(prefs, st.NetMap) { prefsChanged = true } // Perform all mutations of prefs based on the netmap here. - if st.NetMap != nil { - if b.updatePersistFromNetMapLocked(st.NetMap, prefs) { - prefsChanged = true - } - } - // Prefs will be written out if stale; this is not safe unless locked or cloned. if prefsChanged { - if err := b.pm.SetPrefs(prefs.View()); err != nil { + // Prefs will be written out if stale; this is not safe unless locked or cloned. + if err := b.pm.SetPrefs(prefs.View(), st.NetMap.MagicDNSSuffix()); err != nil { b.logf("Failed to save new controlclient state: %v", err) } } @@ -1014,7 +1064,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { // Perform all reconfiguration based on the netmap here. 
if st.NetMap != nil { - b.capTailnetLock = hasCapability(st.NetMap, tailcfg.CapabilityTailnetLockAlpha) + b.capTailnetLock = hasCapability(st.NetMap, tailcfg.CapabilityTailnetLock) b.mu.Unlock() // respect locking rules for tkaSyncIfNeeded if err := b.tkaSyncIfNeeded(st.NetMap, prefs.View()); err != nil { @@ -1053,7 +1103,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { b.mu.Lock() prefs.WantRunning = false p := prefs.View() - if err := b.pm.SetPrefs(p); err != nil { + if err := b.pm.SetPrefs(p, st.NetMap.MagicDNSSuffix()); err != nil { b.logf("Failed to save new controlclient state: %v", err) } b.mu.Unlock() @@ -1070,7 +1120,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { } b.e.SetNetworkMap(st.NetMap) - b.e.SetDERPMap(st.NetMap.DERPMap) + b.magicConn().SetDERPMap(st.NetMap.DERPMap) // Update our cached DERP map dnsfallback.UpdateCache(st.NetMap.DERPMap, b.logf) @@ -1089,9 +1139,94 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { b.authReconfig() } -// findExitNodeIDLocked updates prefs to reference an exit node by ID, rather +var _ controlclient.NetmapDeltaUpdater = (*LocalBackend)(nil) + +// UpdateNetmapDelta implements controlclient.NetmapDeltaUpdater. 
+func (b *LocalBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { + if !b.magicConn().UpdateNetmapDelta(muts) { + return false + } + + var notify *ipn.Notify // non-nil if we need to send a Notify + defer func() { + if notify != nil { + b.send(*notify) + } + }() + + b.mu.Lock() + defer b.mu.Unlock() + if !b.updateNetmapDeltaLocked(muts) { + return false + } + + if b.netMap != nil && mutationsAreWorthyOfTellingIPNBus(muts) { + nm := ptr.To(*b.netMap) // shallow clone + nm.Peers = make([]tailcfg.NodeView, 0, len(b.peers)) + for _, p := range b.peers { + nm.Peers = append(nm.Peers, p) + } + slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { + return cmpx.Compare(a.ID(), b.ID()) + }) + notify = &ipn.Notify{NetMap: nm} + } else if testenv.InTest() { + // In tests, send an empty Notify as a wake-up so end-to-end + // integration tests in another repo can check on the status of + // LocalBackend after processing deltas. + notify = new(ipn.Notify) + } + return true +} + +// mutationsAreWorthyOfTellingIPNBus reports whether any mutation type in muts is +// worthy of spamming the IPN bus (the Windows & Mac GUIs, basically) to tell them +// about the update. +func mutationsAreWorthyOfTellingIPNBus(muts []netmap.NodeMutation) bool { + for _, m := range muts { + switch m.(type) { + case netmap.NodeMutationLastSeen, + netmap.NodeMutationOnline: + // The GUI clients might render peers differently depending on whether + // they're online. + return true + } + } + return false +} + +func (b *LocalBackend) updateNetmapDeltaLocked(muts []netmap.NodeMutation) (handled bool) { + if b.netMap == nil || len(b.peers) == 0 { + return false + } + + // Locally cloned mutable nodes, to avoid calling AsStruct (clone) + // multiple times on a node if it's mutated multiple times in this + // call (e.g. 
its endpoints + online status both change) + var mutableNodes map[tailcfg.NodeID]*tailcfg.Node + + for _, m := range muts { + n, ok := mutableNodes[m.NodeIDBeingMutated()] + if !ok { + nv, ok := b.peers[m.NodeIDBeingMutated()] + if !ok { + // TODO(bradfitz): unexpected metric? + return false + } + n = nv.AsStruct() + mak.Set(&mutableNodes, nv.ID(), n) + } + m.Apply(n) + } + for nid, n := range mutableNodes { + b.peers[nid] = n.View() + } + return true +} + +// setExitNodeID updates prefs to reference an exit node by ID, rather // than by IP. It returns whether prefs was mutated. -func findExitNodeIDLocked(prefs *ipn.Prefs, nm *netmap.NetworkMap) (prefsChanged bool) { +func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap) (prefsChanged bool) { if nm == nil { // No netmap, can't resolve anything. return false @@ -1110,13 +1245,14 @@ func findExitNodeIDLocked(prefs *ipn.Prefs, nm *netmap.NetworkMap) (prefsChanged } for _, peer := range nm.Peers { - for _, addr := range peer.Addresses { + for i := range peer.Addresses().LenIter() { + addr := peer.Addresses().At(i) if !addr.IsSingleIP() || addr.Addr() != prefs.ExitNodeIP { continue } // Found the node being referenced, upgrade prefs to // reference it directly for next time. - prefs.ExitNodeID = peer.StableID + prefs.ExitNodeID = peer.StableID() prefs.ExitNodeIP = netip.Addr{} return true } @@ -1214,6 +1350,27 @@ func (b *LocalBackend) SetControlClientGetterForTesting(newControlClient func(co b.ccGen = newControlClient } +// NodeViewByIDForTest returns the state of the node with the given ID +// for integration tests in another repo. +func (b *LocalBackend) NodeViewByIDForTest(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool) { + b.mu.Lock() + defer b.mu.Unlock() + n, ok := b.peers[id] + return n, ok +} + +// PeersForTest returns all the current peers, sorted by Node.ID, +// for integration tests in another repo. 
+func (b *LocalBackend) PeersForTest() []tailcfg.NodeView { + b.mu.Lock() + defer b.mu.Unlock() + ret := xmaps.Values(b.peers) + slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { + return cmpx.Compare(a.ID(), b.ID()) + }) + return ret +} + func (b *LocalBackend) getNewControlClientFunc() clientGen { b.mu.Lock() defer b.mu.Unlock() @@ -1264,10 +1421,6 @@ func (b *LocalBackend) startIsNoopLocked(opts ipn.Options) bool { // actually a supported operation (it should be, but it's very unclear // from the following whether or not that is a safe transition). func (b *LocalBackend) Start(opts ipn.Options) error { - if opts.LegacyMigrationPrefs == nil && !b.pm.CurrentPrefs().Valid() { - return errors.New("no prefs provided") - } - if opts.LegacyMigrationPrefs != nil { b.logf("Start: %v", opts.LegacyMigrationPrefs.Pretty()) } else { @@ -1275,6 +1428,11 @@ func (b *LocalBackend) Start(opts ipn.Options) error { } b.mu.Lock() + if opts.LegacyMigrationPrefs == nil && !b.pm.CurrentPrefs().Valid() { + b.mu.Unlock() + return errors.New("no prefs provided") + } + if opts.UpdatePrefs != nil { if err := b.checkPrefsLocked(opts.UpdatePrefs); err != nil { b.mu.Unlock() @@ -1293,7 +1451,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { // but meanwhile we can make Start cheaper here for such a // case and not restart the world (which takes a few seconds). // Instead, just send a notify with the state that iOS needs. - if b.startIsNoopLocked(opts) && profileID == b.lastProfileID { + if b.startIsNoopLocked(opts) && profileID == b.lastProfileID && profileID != "" { b.logf("Start: already running; sending notify") nm := b.netMap state := b.state @@ -1314,16 +1472,14 @@ func (b *LocalBackend) Start(opts ipn.Options) error { hostinfo.Userspace.Set(b.sys.IsNetstack()) hostinfo.UserspaceRouter.Set(b.sys.IsNetstackRouter()) - if b.cc != nil { - // TODO(apenwarr): avoid the need to reinit controlclient. 
- // This will trigger a full relogin/reconfigure cycle every - // time a Handle reconnects to the backend. Ideally, we - // would send the new Prefs and everything would get back - // into sync with the minimal changes. But that's not how it - // is right now, which is a sign that the code is still too - // complicated. - b.resetControlClientLockedAsync() - } + // TODO(apenwarr): avoid the need to reinit controlclient. + // This will trigger a full relogin/reconfigure cycle every + // time a Handle reconnects to the backend. Ideally, we + // would send the new Prefs and everything would get back + // into sync with the minimal changes. But that's not how it + // is right now, which is a sign that the code is still too + // complicated. + prevCC := b.resetControlClientLocked() httpTestClient := b.httpTestClient if b.hostinfo != nil { @@ -1342,7 +1498,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { newPrefs := opts.UpdatePrefs.Clone() newPrefs.Persist = oldPrefs.Persist().AsStruct() pv := newPrefs.View() - if err := b.pm.SetPrefs(pv); err != nil { + if err := b.pm.SetPrefs(pv, b.netMap.MagicDNSSuffix()); err != nil { b.logf("failed to save UpdatePrefs state: %v", err) } b.setAtomicValuesFromPrefsLocked(pv) @@ -1352,6 +1508,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { wantRunning := prefs.WantRunning() if wantRunning { if err := b.initMachineKeyLocked(); err != nil { + b.mu.Unlock() return fmt.Errorf("initMachineKeyLocked: %w", err) } } @@ -1380,19 +1537,19 @@ func (b *LocalBackend) Start(opts ipn.Options) error { // prevent it from restarting our map poll // HTTP request (via doSetHostinfoFilterServices > // cli.SetHostinfo). In practice this is very quick. 
- t0 := time.Now() - timer := time.NewTimer(time.Second) + t0 := b.clock.Now() + timer, timerChannel := b.clock.NewTimer(time.Second) select { case <-b.gotPortPollRes: - b.logf("[v1] got initial portlist info in %v", time.Since(t0).Round(time.Millisecond)) + b.logf("[v1] got initial portlist info in %v", b.clock.Since(t0).Round(time.Millisecond)) timer.Stop() - case <-timer.C: + case <-timerChannel: b.logf("timeout waiting for initial portlist") } }) } - discoPublic := b.e.DiscoPublicKey() + discoPublic := b.magicConn().DiscoPublicKey() var err error @@ -1402,6 +1559,10 @@ func (b *LocalBackend) Start(opts ipn.Options) error { debugFlags = append([]string{"netstack"}, debugFlags...) } + if prevCC != nil { + prevCC.Shutdown() + } + // TODO(apenwarr): The only way to change the ServerURL is to // re-run b.Start(), because this is the only place we create a // new controlclient. SetPrefs() allows you to overwrite ServerURL, @@ -1413,8 +1574,6 @@ func (b *LocalBackend) Start(opts ipn.Options) error { ServerURL: serverURL, AuthKey: opts.AuthKey, Hostinfo: hostinfo, - KeepAlive: true, - NewDecompressor: b.newDecompressor, HTTPTestClient: httpTestClient, DiscoPublicKey: discoPublic, DebugFlags: debugFlags, @@ -1424,9 +1583,10 @@ func (b *LocalBackend) Start(opts ipn.Options) error { OnClientVersion: b.onClientVersion, OnControlTime: b.em.onControlTime, Dialer: b.Dialer(), - Status: b.setClientStatus, + Observer: b, C2NHandler: http.HandlerFunc(b.handleC2N), DialPlan: &b.dialPlan, // pointer because it can't be copied + ControlKnobs: b.sys.ControlKnobs(), // Don't warn about broken Linux IP forwarding when // netstack is being used. @@ -1437,6 +1597,13 @@ func (b *LocalBackend) Start(opts ipn.Options) error { } b.mu.Lock() + // Even though we reset b.cc above, we might have raced with + // another Start() call. If so, shut down the previous one again + // as we do not know if it was created with the same options. 
+ prevCC = b.resetControlClientLocked() + if prevCC != nil { + defer prevCC.Shutdown() // must be called after b.mu is unlocked + } b.cc = cc b.ccAuto, _ = cc.(*controlclient.Auto) endpoints := b.endpoints @@ -1460,7 +1627,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { } cc.SetTKAHead(tkaHead) - b.e.SetNetInfoCallback(b.setNetInfo) + b.magicConn().SetNetInfoCallback(b.setNetInfo) blid := b.backendLogID.String() b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID) @@ -1491,7 +1658,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P // quite hard to debug, so save yourself the trouble. var ( haveNetmap = netMap != nil - addrs []netip.Prefix + addrs views.Slice[netip.Prefix] packetFilter []filter.Match localNetsB netipx.IPSetBuilder logNetsB netipx.IPSetBuilder @@ -1502,13 +1669,13 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P logNetsB.AddPrefix(tsaddr.TailscaleULARange()) logNetsB.RemovePrefix(tsaddr.ChromeOSVMRange()) if haveNetmap { - addrs = netMap.Addresses - for _, p := range addrs { - localNetsB.AddPrefix(p) + addrs = netMap.GetAddresses() + for i := range addrs.LenIter() { + localNetsB.AddPrefix(addrs.At(i)) } packetFilter = netMap.PacketFilter - if packetFilterPermitsUnlockedNodes(netMap.Peers, packetFilter) { + if packetFilterPermitsUnlockedNodes(b.peers, packetFilter) { err := errors.New("server sent invalid packet filter permitting traffic to unlocked nodes; rejecting all packets for safety") warnInvalidUnsignedNodes.Set(err) packetFilter = nil @@ -1556,7 +1723,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P changed := deephash.Update(&b.filterHash, &struct { HaveNetmap bool - Addrs []netip.Prefix + Addrs views.Slice[netip.Prefix] FilterMatch []filter.Match LocalNets []netipx.IPRange LogNets []netipx.IPRange @@ -1593,16 +1760,16 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P // // If this 
reports true, the packet filter is invalid (the server is either broken // or malicious) and should be ignored for safety. -func packetFilterPermitsUnlockedNodes(peers []*tailcfg.Node, packetFilter []filter.Match) bool { +func packetFilterPermitsUnlockedNodes(peers map[tailcfg.NodeID]tailcfg.NodeView, packetFilter []filter.Match) bool { var b netipx.IPSetBuilder var numUnlocked int for _, p := range peers { - if !p.UnsignedPeerAPIOnly { + if !p.UnsignedPeerAPIOnly() { continue } numUnlocked++ - for _, a := range p.AllowedIPs { // not only addresses! - b.AddPrefix(a) + for i := range p.AllowedIPs().LenIter() { // not only addresses! + b.AddPrefix(p.AllowedIPs().At(i)) } } if numUnlocked == 0 { @@ -1758,64 +1925,17 @@ func shrinkDefaultRoute(route netip.Prefix, localInterfaceRoutes *netipx.IPSet, return b.IPSet() } -// dnsCIDRsEqual determines whether two CIDR lists are equal -// for DNS map construction purposes (that is, only the first entry counts). -func dnsCIDRsEqual(newAddr, oldAddr []netip.Prefix) bool { - if len(newAddr) != len(oldAddr) { - return false - } - if len(newAddr) == 0 || newAddr[0] == oldAddr[0] { - return true - } - return false -} - -// dnsMapsEqual determines whether the new and the old network map -// induce the same DNS map. It does so without allocating memory, -// at the expense of giving false negatives if peers are reordered. 
-func dnsMapsEqual(new, old *netmap.NetworkMap) bool { - if (old == nil) != (new == nil) { - return false - } - if old == nil && new == nil { - return true - } - - if len(new.Peers) != len(old.Peers) { - return false - } - - if new.Name != old.Name { - return false - } - if !dnsCIDRsEqual(new.Addresses, old.Addresses) { - return false - } - - for i, newPeer := range new.Peers { - oldPeer := old.Peers[i] - if newPeer.Name != oldPeer.Name { - return false - } - if !dnsCIDRsEqual(newPeer.Addresses, oldPeer.Addresses) { - return false - } - } - - return true -} - // readPoller is a goroutine that receives service lists from // b.portpoll and propagates them into the controlclient's HostInfo. func (b *LocalBackend) readPoller() { isFirst := true - ticker := time.NewTicker(portlist.PollInterval()) + ticker, tickerChannel := b.clock.NewTicker(portlist.PollInterval()) defer ticker.Stop() initChan := make(chan struct{}) close(initChan) for { select { - case <-ticker.C: + case <-tickerChannel: case <-b.ctx.Done(): return case <-initChan: @@ -1895,6 +2015,8 @@ func (b *LocalBackend) ResendHostinfoIfNeeded() { func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWatchOpt, onWatchAdded func(), fn func(roNotify *ipn.Notify) (keepGoing bool)) { ch := make(chan *ipn.Notify, 128) + sessionID := rands.HexString(16) + origFn := fn if mask&ipn.NotifyNoPrivateKeys != 0 { fn = func(n *ipn.Notify) bool { @@ -1916,10 +2038,13 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa var ini *ipn.Notify b.mu.Lock() + b.activeWatchSessions.Add(sessionID) + const initialBits = ipn.NotifyInitialState | ipn.NotifyInitialPrefs | ipn.NotifyInitialNetMap if mask&initialBits != 0 { ini = &ipn.Notify{Version: version.Long()} if mask&ipn.NotifyInitialState != 0 { + ini.SessionID = sessionID ini.State = ptr.To(b.state) if b.state == ipn.NeedsLogin { ini.BrowseToURL = ptr.To(b.authURLSticky) @@ -1939,6 +2064,7 @@ func (b *LocalBackend) 
WatchNotifications(ctx context.Context, mask ipn.NotifyWa defer func() { b.mu.Lock() delete(b.notifyWatchers, handle) + delete(b.activeWatchSessions, sessionID) b.mu.Unlock() }() @@ -1969,6 +2095,10 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa go b.pollRequestEngineStatus(ctx) } + // TODO(marwan-at-work): check err + // TODO(marwan-at-work): streaming background logs? + defer b.DeleteForegroundSession(sessionID) + for { select { case <-ctx.Done(): @@ -1984,11 +2114,11 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa // pollRequestEngineStatus calls b.RequestEngineStatus every 2 seconds until ctx // is done. func (b *LocalBackend) pollRequestEngineStatus(ctx context.Context) { - ticker := time.NewTicker(2 * time.Second) + ticker, tickerChannel := b.clock.NewTicker(2 * time.Second) defer ticker.Stop() for { select { - case <-ticker.C: + case <-tickerChannel: b.RequestEngineStatus() case <-ctx.Done(): return @@ -2003,6 +2133,23 @@ func (b *LocalBackend) DebugNotify(n ipn.Notify) { b.send(n) } +// DebugForceNetmapUpdate forces a full no-op netmap update of the current +// netmap in all the various subsystems (wireguard, magicsock, LocalBackend). +// +// It exists for load testing reasons (for issue 1909), doing what would happen +// if a new MapResponse came in from the control server that couldn't be handled +// incrementally. +func (b *LocalBackend) DebugForceNetmapUpdate() { + b.mu.Lock() + defer b.mu.Unlock() + nm := b.netMap + b.e.SetNetworkMap(nm) + if nm != nil { + b.magicConn().SetDERPMap(nm.DERPMap) + } + b.setNetMapLocked(nm) +} + // send delivers n to the connected frontend and any API watchers from // LocalBackend.WatchNotifications (via the LocalAPI). // @@ -2124,6 +2271,9 @@ func (b *LocalBackend) tellClientToBrowseToURL(url string) { // onClientVersion is called on MapResponse updates when a MapResponse contains // a non-nil ClientVersion message. 
func (b *LocalBackend) onClientVersion(v *tailcfg.ClientVersion) { + b.mu.Lock() + b.lastClientVersion = v + b.mu.Unlock() switch runtime.GOOS { case "darwin", "ios": // These auto-update well enough, and we haven't converted the @@ -2205,7 +2355,7 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { } keyText, _ = b.machinePrivKey.MarshalText() - if err := b.store.WriteState(ipn.MachineKeyStateKey, keyText); err != nil { + if err := ipn.WriteState(b.store, ipn.MachineKeyStateKey, keyText); err != nil { b.logf("error writing machine key to store: %v", err) return err } @@ -2220,7 +2370,7 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { // // b.mu must be held. func (b *LocalBackend) clearMachineKeyLocked() error { - if err := b.store.WriteState(ipn.MachineKeyStateKey, nil); err != nil { + if err := ipn.WriteState(b.store, ipn.MachineKeyStateKey, nil); err != nil { return err } b.machinePrivKey = key.MachinePrivate{} @@ -2239,7 +2389,7 @@ func (b *LocalBackend) migrateStateLocked(prefs *ipn.Prefs) (err error) { // Backend owns the state, but frontend is trying to migrate // state into the backend. 
b.logf("importing frontend prefs into backend store; frontend prefs: %s", prefs.Pretty()) - if err := b.pm.SetPrefs(prefs.View()); err != nil { + if err := b.pm.SetPrefs(prefs.View(), b.netMap.MagicDNSSuffix()); err != nil { return fmt.Errorf("store.WriteState: %v", err) } } @@ -2292,12 +2442,13 @@ func (b *LocalBackend) setAtomicValuesFromPrefsLocked(p ipn.PrefsView) { b.sshAtomicBool.Store(p.Valid() && p.RunSSH() && envknob.CanSSHD()) if !p.Valid() { - b.containsViaIPFuncAtomic.Store(tsaddr.NewContainsIPFunc(nil)) + b.containsViaIPFuncAtomic.Store(tsaddr.FalseContainsIPFunc()) b.setTCPPortsIntercepted(nil) b.lastServeConfJSON = mem.B(nil) b.serveConfig = ipn.ServeConfigView{} } else { - b.containsViaIPFuncAtomic.Store(tsaddr.NewContainsIPFunc(p.AdvertiseRoutes().Filter(tsaddr.IsViaPrefix))) + filtered := tsaddr.FilterPrefixesCopy(p.AdvertiseRoutes(), tsaddr.IsViaPrefix) + b.containsViaIPFuncAtomic.Store(tsaddr.NewContainsIPFunc(views.SliceOf(filtered))) b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(p) } } @@ -2396,14 +2547,14 @@ func (b *LocalBackend) StartLoginInteractive() { } } -func (b *LocalBackend) Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType) (*ipnstate.PingResult, error) { +func (b *LocalBackend) Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType, size int) (*ipnstate.PingResult, error) { if pingType == tailcfg.PingPeerAPI { - t0 := time.Now() + t0 := b.clock.Now() node, base, err := b.pingPeerAPI(ctx, ip) if err != nil && ctx.Err() != nil { return nil, ctx.Err() } - d := time.Since(t0) + d := b.clock.Since(t0) pr := &ipnstate.PingResult{ IP: ip.String(), NodeIP: ip.String(), @@ -2413,13 +2564,13 @@ func (b *LocalBackend) Ping(ctx context.Context, ip netip.Addr, pingType tailcfg if err != nil { pr.Err = err.Error() } - if node != nil { - pr.NodeName = node.Name + if node.Valid() { + pr.NodeName = node.Name() } return pr, nil } ch := make(chan *ipnstate.PingResult, 1) - b.e.Ping(ip, pingType, func(pr 
*ipnstate.PingResult) { + b.e.Ping(ip, pingType, size, func(pr *ipnstate.PingResult) { select { case ch <- pr: default: @@ -2433,36 +2584,37 @@ func (b *LocalBackend) Ping(ctx context.Context, ip netip.Addr, pingType tailcfg } } -func (b *LocalBackend) pingPeerAPI(ctx context.Context, ip netip.Addr) (peer *tailcfg.Node, peerBase string, err error) { +func (b *LocalBackend) pingPeerAPI(ctx context.Context, ip netip.Addr) (peer tailcfg.NodeView, peerBase string, err error) { + var zero tailcfg.NodeView ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() nm := b.NetMap() if nm == nil { - return nil, "", errors.New("no netmap") + return zero, "", errors.New("no netmap") } peer, ok := nm.PeerByTailscaleIP(ip) if !ok { - return nil, "", fmt.Errorf("no peer found with Tailscale IP %v", ip) + return zero, "", fmt.Errorf("no peer found with Tailscale IP %v", ip) } - if peer.Expired { - return nil, "", errors.New("peer's node key has expired") + if peer.Expired() { + return zero, "", errors.New("peer's node key has expired") } base := peerAPIBase(nm, peer) if base == "" { - return nil, "", fmt.Errorf("no PeerAPI base found for peer %v (%v)", peer.ID, ip) + return zero, "", fmt.Errorf("no PeerAPI base found for peer %v (%v)", peer.ID(), ip) } outReq, err := http.NewRequestWithContext(ctx, "HEAD", base, nil) if err != nil { - return nil, "", err + return zero, "", err } tr := b.Dialer().PeerAPITransport() res, err := tr.RoundTrip(outReq) if err != nil { - return nil, "", err + return zero, "", err } defer res.Body.Close() // but unnecessary on HEAD responses if res.StatusCode != http.StatusOK { - return nil, "", fmt.Errorf("HTTP status %v", res.Status) + return zero, "", fmt.Errorf("HTTP status %v", res.Status) } return peer, base, nil } @@ -2730,7 +2882,7 @@ func (b *LocalBackend) SetPrefs(newp *ipn.Prefs) { // doesn't affect security or correctness. And we also don't expect people to // modify their ServeConfig in raw mode. 
func (b *LocalBackend) wantIngressLocked() bool { - return b.serveConfig.Valid() && b.serveConfig.AllowFunnel().Len() > 0 + return b.serveConfig.Valid() && b.serveConfig.HasAllowFunnel() } // setPrefsLockedOnEntry requires b.mu be held to call it, but it @@ -2747,7 +2899,7 @@ func (b *LocalBackend) setPrefsLockedOnEntry(caller string, newp *ipn.Prefs) ipn // findExitNodeIDLocked returns whether it updated b.prefs, but // everything in this function treats b.prefs as completely new // anyway. No-op if no exit node resolution is needed. - findExitNodeIDLocked(newp, netMap) + setExitNodeID(newp, netMap) // We do this to avoid holding the lock while doing everything else. oldHi := b.hostinfo @@ -2770,22 +2922,22 @@ func (b *LocalBackend) setPrefsLockedOnEntry(caller string, newp *ipn.Prefs) ipn } } if netMap != nil { - up := netMap.UserProfiles[netMap.User] - if login := up.LoginName; login != "" { - if newp.Persist == nil { - b.logf("active login: %s", login) + newProfile := netMap.UserProfiles[netMap.User()] + if newLoginName := newProfile.LoginName; newLoginName != "" { + if !oldp.Persist().Valid() { + b.logf("active login: %s", newLoginName) } else { - if newp.Persist.LoginName != login { - b.logf("active login: %q (changed from %q)", login, newp.Persist.LoginName) - newp.Persist.LoginName = login + oldLoginName := oldp.Persist().UserProfile().LoginName() + if oldLoginName != newLoginName { + b.logf("active login: %q (changed from %q)", newLoginName, oldLoginName) } - newp.Persist.UserProfile = up + newp.Persist.UserProfile = newProfile } } } prefs := newp.View() - if err := b.pm.SetPrefs(prefs); err != nil { + if err := b.pm.SetPrefs(prefs, b.netMap.MagicDNSSuffix()); err != nil { b.logf("failed to save new controlclient state: %v", err) } b.lastProfileID = b.pm.CurrentProfile().ID @@ -2796,7 +2948,7 @@ func (b *LocalBackend) setPrefsLockedOnEntry(caller string, newp *ipn.Prefs) ipn } if netMap != nil { - b.e.SetDERPMap(netMap.DERPMap) + 
b.magicConn().SetDERPMap(netMap.DERPMap) } if !oldp.WantRunning() && newp.WantRunning { @@ -2850,7 +3002,7 @@ func (b *LocalBackend) handlePeerAPIConn(remote, local netip.AddrPort, c net.Con func (b *LocalBackend) isLocalIP(ip netip.Addr) bool { nm := b.NetMap() - return nm != nil && slices.Contains(nm.Addresses, netip.PrefixFrom(ip, ip.BitLen())) + return nm != nil && views.SliceContains(nm.GetAddresses(), netip.PrefixFrom(ip, ip.BitLen())) } var ( @@ -2983,7 +3135,9 @@ func (b *LocalBackend) authReconfig() { prefs := b.pm.CurrentPrefs() nm := b.netMap hasPAC := b.prevIfState.HasPAC() - disableSubnetsIfPAC := nm != nil && nm.Debug != nil && nm.Debug.DisableSubnetsIfPAC.EqualBool(true) + disableSubnetsIfPAC := hasCapability(nm, tailcfg.NodeAttrDisableSubnetsIfPAC) + dohURL, dohURLOK := exitNodeCanProxyDNS(nm, b.peers, prefs.ExitNodeID()) + dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.logf, version.OS()) b.mu.Unlock() if blocked { @@ -3016,7 +3170,7 @@ func (b *LocalBackend) authReconfig() { // Keep the dialer updated about whether we're supposed to use // an exit node's DNS server (so SOCKS5/HTTP outgoing dials // can use it for name resolution) - if dohURL, ok := exitNodeCanProxyDNS(nm, prefs.ExitNodeID()); ok { + if dohURLOK { b.dialer.SetExitDNSDoH(dohURL) } else { b.dialer.SetExitDNSDoH("") @@ -3028,11 +3182,10 @@ func (b *LocalBackend) authReconfig() { return } - oneCGNATRoute := shouldUseOneCGNATRoute(nm, b.logf, version.OS()) + oneCGNATRoute := shouldUseOneCGNATRoute(b.logf, b.sys.ControlKnobs(), version.OS()) rcfg := b.routerConfig(cfg, prefs, oneCGNATRoute) - dcfg := dnsConfigForNetmap(nm, prefs, b.logf, version.OS()) - err = b.e.Reconfig(cfg, rcfg, dcfg, nm.Debug) + err = b.e.Reconfig(cfg, rcfg, dcfg) if err == wgengine.ErrNoChanges { return } @@ -3046,14 +3199,15 @@ func (b *LocalBackend) authReconfig() { // // The versionOS is a Tailscale-style version ("iOS", "macOS") and not // a runtime.GOOS. 
-func shouldUseOneCGNATRoute(nm *netmap.NetworkMap, logf logger.Logf, versionOS string) bool { - // Explicit enabling or disabling always take precedence. - if nm.Debug != nil { - if v, ok := nm.Debug.OneCGNATRoute.Get(); ok { +func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs, versionOS string) bool { + if controlKnobs != nil { + // Explicit enabling or disabling always take precedence. + if v, ok := controlKnobs.OneCGNAT.Load().Get(); ok { logf("[v1] shouldUseOneCGNATRoute: explicit=%v", v) return v } } + // Also prefer to do this on the Mac, so that we don't need to constantly // update the network extension configuration (which is disruptive to // Chrome, see https://github.com/tailscale/tailscale/issues/3102). Only @@ -3078,15 +3232,18 @@ func shouldUseOneCGNATRoute(nm *netmap.NetworkMap, logf logger.Logf, versionOS s // // The versionOS is a Tailscale-style version ("iOS", "macOS") and not // a runtime.GOOS. -func dnsConfigForNetmap(nm *netmap.NetworkMap, prefs ipn.PrefsView, logf logger.Logf, versionOS string) *dns.Config { +func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, logf logger.Logf, versionOS string) *dns.Config { + if nm == nil { + return nil + } dcfg := &dns.Config{ Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, Hosts: map[dnsname.FQDN][]netip.Addr{}, } // selfV6Only is whether we only have IPv6 addresses ourselves. - selfV6Only := slices.ContainsFunc(nm.Addresses, tsaddr.PrefixIs6) && - !slices.ContainsFunc(nm.Addresses, tsaddr.PrefixIs4) + selfV6Only := views.SliceContainsFunc(nm.GetAddresses(), tsaddr.PrefixIs6) && + !views.SliceContainsFunc(nm.GetAddresses(), tsaddr.PrefixIs4) dcfg.OnlyIPv6 = selfV6Only // Populate MagicDNS records. We do this unconditionally so that @@ -3094,17 +3251,24 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, prefs ipn.PrefsView, logf logger. // isn't configured to make MagicDNS resolution truly // magic. 
Details in // https://github.com/tailscale/tailscale/issues/1886. - set := func(name string, addrs []netip.Prefix) { - if len(addrs) == 0 || name == "" { + set := func(name string, addrs views.Slice[netip.Prefix]) { + if addrs.Len() == 0 || name == "" { return } fqdn, err := dnsname.ToFQDN(name) if err != nil { return // TODO: propagate error? } - have4 := slices.ContainsFunc(addrs, tsaddr.PrefixIs4) + var have4 bool + for i := range addrs.LenIter() { + if addrs.At(i).Addr().Is4() { + have4 = true + break + } + } var ips []netip.Addr - for _, addr := range addrs { + for i := range addrs.LenIter() { + addr := addrs.At(i) if selfV6Only { if addr.Addr().Is6() { ips = append(ips, addr.Addr()) @@ -3126,9 +3290,9 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, prefs ipn.PrefsView, logf logger. } dcfg.Hosts[fqdn] = ips } - set(nm.Name, nm.Addresses) - for _, peer := range nm.Peers { - set(peer.Name, peer.Addresses) + set(nm.Name, nm.GetAddresses()) + for _, peer := range peers { + set(peer.Name(), peer.Addresses()) } for _, rec := range nm.DNS.ExtraRecords { switch rec.Type { @@ -3175,7 +3339,7 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, prefs ipn.PrefsView, logf logger. // If we're using an exit node and that exit node is new enough (1.19.x+) // to run a DoH DNS proxy, then send all our DNS traffic through it. 
- if dohURL, ok := exitNodeCanProxyDNS(nm, prefs.ExitNodeID()); ok { + if dohURL, ok := exitNodeCanProxyDNS(nm, peers, prefs.ExitNodeID()); ok { addDefault([]*dnstype.Resolver{{Addr: dohURL}}) return dcfg } @@ -3341,10 +3505,11 @@ func (b *LocalBackend) initPeerAPIListener() { return } - if len(b.netMap.Addresses) == len(b.peerAPIListeners) { + addrs := b.netMap.GetAddresses() + if addrs.Len() == len(b.peerAPIListeners) { allSame := true for i, pln := range b.peerAPIListeners { - if pln.ip != b.netMap.Addresses[i].Addr() { + if pln.ip != addrs.At(i).Addr() { allSame = false break } @@ -3358,11 +3523,11 @@ func (b *LocalBackend) initPeerAPIListener() { b.closePeerAPIListenersLocked() selfNode := b.netMap.SelfNode - if len(b.netMap.Addresses) == 0 || selfNode == nil { + if !selfNode.Valid() || b.netMap.GetAddresses().Len() == 0 { return } - fileRoot := b.fileRootLocked(selfNode.User) + fileRoot := b.fileRootLocked(selfNode.User()) if fileRoot == "" { b.logf("peerapi starting without Taildrop directory configured") } @@ -3379,7 +3544,8 @@ func (b *LocalBackend) initPeerAPIListener() { b.peerAPIServer = ps isNetstack := b.sys.IsNetstack() - for i, a := range b.netMap.Addresses { + for i := range addrs.LenIter() { + a := addrs.At(i) var ln net.Listener var err error skipListen := i > 0 && isNetstack @@ -3631,7 +3797,7 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State) { // Transitioning away from running. 
b.closePeerAPIListenersLocked() } - b.maybePauseControlClientLocked() + b.pauseOrResumeControlClientLocked() b.mu.Unlock() // prefs may change irrespective of state; WantRunning should be explicitly @@ -3650,7 +3816,7 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State) { b.blockEngineUpdates(true) fallthrough case ipn.Stopped: - err := b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}, nil) + err := b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) if err != nil { b.logf("Reconfig(down): %v", err) } @@ -3663,11 +3829,12 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State) { // Needed so that UpdateEndpoints can run b.e.RequestStatus() case ipn.Running: - var addrs []string - for _, addr := range netMap.Addresses { - addrs = append(addrs, addr.Addr().String()) + var addrStrs []string + addrs := netMap.GetAddresses() + for i := range addrs.LenIter() { + addrStrs = append(addrStrs, addrs.At(i).Addr().String()) } - systemd.Status("Connected; %s; %s", activeLogin, strings.Join(addrs, " ")) + systemd.Status("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) case ipn.NoState: // Do nothing. default: @@ -3676,10 +3843,16 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State) { } +// hasNodeKey reports whether a non-zero node key is present in the current +// prefs. func (b *LocalBackend) hasNodeKey() bool { - // we can't use b.Prefs(), because it strips the keys, oops! b.mu.Lock() defer b.mu.Unlock() + return b.hasNodeKeyLocked() +} + +func (b *LocalBackend) hasNodeKeyLocked() bool { + // we can't use b.Prefs(), because it strips the keys, oops! 
p := b.pm.CurrentPrefs() return p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() } @@ -3688,19 +3861,17 @@ func (b *LocalBackend) hasNodeKey() bool { func (b *LocalBackend) NodeKey() key.NodePublic { b.mu.Lock() defer b.mu.Unlock() - - p := b.pm.CurrentPrefs() - if !p.Valid() || !p.Persist().Valid() || p.Persist().PrivateNodeKey().IsZero() { + if !b.hasNodeKeyLocked() { return key.NodePublic{} } - - return p.Persist().PublicNodeKey() + return b.pm.CurrentPrefs().Persist().PublicNodeKey() } -// nextState returns the state the backend seems to be in, based on +// nextStateLocked returns the state the backend seems to be in, based on // its internal state. -func (b *LocalBackend) nextState() ipn.State { - b.mu.Lock() +// +// b.mu must be held +func (b *LocalBackend) nextStateLocked() ipn.State { var ( cc = b.cc netMap = b.netMap @@ -3716,10 +3887,9 @@ func (b *LocalBackend) nextState() ipn.State { wantRunning = p.WantRunning() loggedOut = p.LoggedOut() } - b.mu.Unlock() switch { - case !wantRunning && !loggedOut && !blocked && b.hasNodeKey(): + case !wantRunning && !loggedOut && !blocked && b.hasNodeKeyLocked(): return ipn.Stopped case netMap == nil: if (cc != nil && cc.AuthCantContinue()) || loggedOut { @@ -3751,7 +3921,7 @@ func (b *LocalBackend) nextState() ipn.State { // NetMap must be non-nil for us to get here. // The node key expired, need to relogin. return ipn.NeedsLogin - case netMap.MachineStatus != tailcfg.MachineAuthorized: + case netMap.GetMachineStatus() != tailcfg.MachineAuthorized: // TODO(crawshaw): handle tailcfg.MachineInvalid return ipn.NeedsMachineAuth case state == ipn.NeedsMachineAuth: @@ -3782,7 +3952,8 @@ func (b *LocalBackend) RequestEngineStatus() { // TODO(apenwarr): use a channel or something to prevent reentrancy? // Or maybe just call the state machine from fewer places. 
func (b *LocalBackend) stateMachine() { - b.enterState(b.nextState()) + b.mu.Lock() + b.enterStateLockedOnEntry(b.nextStateLocked()) } // stopEngineAndWait deconfigures the local network data plane, and @@ -3792,7 +3963,7 @@ func (b *LocalBackend) stateMachine() { // a status update that predates the "I've shut down" update. func (b *LocalBackend) stopEngineAndWait() { b.logf("stopEngineAndWait...") - b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}, nil) + b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) b.requestEngineStatusAndWait() b.logf("stopEngineAndWait: done.") } @@ -3810,12 +3981,12 @@ func (b *LocalBackend) requestEngineStatusAndWait() { b.statusLock.Unlock() } -// resetControlClientLockedAsync sets b.cc to nil, and starts a -// goroutine to Shutdown the old client. It does not wait for the -// shutdown to complete. -func (b *LocalBackend) resetControlClientLockedAsync() { +// resetControlClientLocked sets b.cc to nil and returns the old value. If the +// returned value is non-nil, the caller must call Shutdown on it after +// releasing b.mu. +func (b *LocalBackend) resetControlClientLocked() controlclient.Client { if b.cc == nil { - return + return nil } // When we clear the control client, stop any outstanding netmap expiry @@ -3831,10 +4002,10 @@ func (b *LocalBackend) resetControlClientLockedAsync() { // will abort. b.numClientStatusCalls.Add(1) } - - go b.cc.Shutdown() + prev := b.cc b.cc = nil b.ccAuto = nil + return prev } // ResetForClientDisconnect resets the backend for GUI clients running @@ -3844,11 +4015,15 @@ func (b *LocalBackend) resetControlClientLockedAsync() { // don't want to the user to have to reauthenticate in the future // when they restart the GUI. 
func (b *LocalBackend) ResetForClientDisconnect() { - defer b.enterState(ipn.Stopped) - b.mu.Lock() - defer b.mu.Unlock() b.logf("LocalBackend.ResetForClientDisconnect") - b.resetControlClientLockedAsync() + + b.mu.Lock() + prevCC := b.resetControlClientLocked() + if prevCC != nil { + // Needs to happen without b.mu held. + defer prevCC.Shutdown() + } + b.setNetMapLocked(nil) b.pm.Reset() b.keyExpired = false @@ -3856,6 +4031,7 @@ func (b *LocalBackend) ResetForClientDisconnect() { b.authURLSticky = "" b.activeLogin = "" b.setAtomicValuesFromPrefsLocked(ipn.PrefsView{}) + b.enterStateLockedOnEntry(ipn.Stopped) } func (b *LocalBackend) ShouldRunSSH() bool { return b.sshAtomicBool.Load() && envknob.CanSSHD() } @@ -3870,27 +4046,30 @@ func (b *LocalBackend) ShouldHandleViaIP(ip netip.Addr) bool { return false } -// Logout tells the controlclient that we want to log out, and -// transitions the local engine to the logged-out state without -// waiting for controlclient to be in that state. -func (b *LocalBackend) Logout() { - b.logout(context.Background(), false) -} - -func (b *LocalBackend) LogoutSync(ctx context.Context) error { - return b.logout(ctx, true) -} - -func (b *LocalBackend) logout(ctx context.Context, sync bool) error { +// Logout logs out the current profile, if any, and waits for the logout to +// complete. +func (b *LocalBackend) Logout(ctx context.Context) error { b.mu.Lock() + if !b.hasNodeKeyLocked() { + // Already logged out. + b.mu.Unlock() + return nil + } cc := b.cc + + // Grab the current profile before we unlock the mutex, so that we can + // delete it later. + profile := b.pm.CurrentProfile() b.mu.Unlock() - b.EditPrefs(&ipn.MaskedPrefs{ + _, err := b.EditPrefs(&ipn.MaskedPrefs{ WantRunningSet: true, LoggedOutSet: true, Prefs: ipn.Prefs{WantRunning: false, LoggedOut: true}, }) + if err != nil { + return err + } // Clear any previous dial plan(s), if set. 
b.dialPlan.Store(nil) @@ -3906,15 +4085,16 @@ func (b *LocalBackend) logout(ctx context.Context, sync bool) error { return errors.New("no controlclient") } - var err error - if sync { - err = cc.Logout(ctx) - } else { - cc.StartLogout() + if err := cc.Logout(ctx); err != nil { + return err } - - b.stateMachine() - return err + b.mu.Lock() + if err := b.pm.DeleteProfile(profile.ID); err != nil { + b.mu.Unlock() + b.logf("error deleting profile: %v", err) + return err + } + return b.resetForProfileChangeLockedOnEntry() } // assertClientLocked crashes if there is no controlclient in this backend. @@ -3937,51 +4117,33 @@ func (b *LocalBackend) setNetInfo(ni *tailcfg.NetInfo) { cc.SetNetInfo(ni) } -func hasCapability(nm *netmap.NetworkMap, cap string) bool { - if nm != nil && nm.SelfNode != nil { - for _, c := range nm.SelfNode.Capabilities { - if c == cap { - return true - } - } +func hasCapability(nm *netmap.NetworkMap, cap tailcfg.NodeCapability) bool { + if nm != nil { + return nm.SelfNode.HasCap(cap) } return false } -func (b *LocalBackend) updatePersistFromNetMapLocked(nm *netmap.NetworkMap, prefs *ipn.Prefs) (changed bool) { - if nm == nil || nm.SelfNode == nil { - return - } - up := nm.UserProfiles[nm.User] - if prefs.Persist.UserProfile.ID != up.ID { - // If the current profile doesn't match the - // network map's user profile, then we need to - // update the persisted UserProfile to match. - prefs.Persist.UserProfile = up - changed = true - } - if prefs.Persist.NodeID == "" { - // If the current profile doesn't have a NodeID, - // then we need to update the persisted NodeID to - // match. - prefs.Persist.NodeID = nm.SelfNode.StableID - changed = true - } - return changed -} - +// setNetMapLocked updates the LocalBackend state to reflect the newly +// received nm. If nm is nil, it resets all configuration as though +// Tailscale is turned off. 
func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { b.dialer.SetNetMap(nm) + if ns, ok := b.sys.Netstack.GetOK(); ok { + ns.UpdateNetstackIPs(nm) + } var login string if nm != nil { - login = cmpx.Or(nm.UserProfiles[nm.User].LoginName, "") + login = cmpx.Or(nm.UserProfiles[nm.User()].LoginName, "") } b.netMap = nm + b.updatePeersFromNetmapLocked(nm) if login != b.activeLogin { b.logf("active login: %v", login) b.activeLogin = login + b.lastProfileID = b.pm.CurrentProfile().ID } - b.maybePauseControlClientLocked() + b.pauseOrResumeControlClientLocked() if nm != nil { health.SetControlHealth(nm.ControlHealth) @@ -4010,20 +4172,20 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { // Update the nodeByAddr index. if b.nodeByAddr == nil { - b.nodeByAddr = map[netip.Addr]*tailcfg.Node{} + b.nodeByAddr = map[netip.Addr]tailcfg.NodeID{} } // First pass, mark everything unwanted. for k := range b.nodeByAddr { - b.nodeByAddr[k] = nil + b.nodeByAddr[k] = 0 } - addNode := func(n *tailcfg.Node) { - for _, ipp := range n.Addresses { - if ipp.IsSingleIP() { - b.nodeByAddr[ipp.Addr()] = n + addNode := func(n tailcfg.NodeView) { + for i := range n.Addresses().LenIter() { + if ipp := n.Addresses().At(i); ipp.IsSingleIP() { + b.nodeByAddr[ipp.Addr()] = n.ID() } } } - if nm.SelfNode != nil { + if nm.SelfNode.Valid() { addNode(nm.SelfNode) } for _, p := range nm.Peers { @@ -4031,12 +4193,33 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { } // Third pass, actually delete the unwanted items. for k, v := range b.nodeByAddr { - if v == nil { + if v == 0 { delete(b.nodeByAddr, k) } } } +func (b *LocalBackend) updatePeersFromNetmapLocked(nm *netmap.NetworkMap) { + if nm == nil { + b.peers = nil + return + } + // First pass, mark everything unwanted. + for k := range b.peers { + b.peers[k] = tailcfg.NodeView{} + } + // Second pass, add everything wanted. 
+ for _, p := range nm.Peers { + mak.Set(&b.peers, p.ID(), p) + } + // Third pass, remove deleted things. + for k, v := range b.peers { + if !v.Valid() { + delete(b.peers, k) + } + } +} + // setDebugLogsByCapabilityLocked sets debug logging based on the self node's // capabilities in the provided NetMap. func (b *LocalBackend) setDebugLogsByCapabilityLocked(nm *netmap.NetworkMap) { @@ -4049,14 +4232,19 @@ func (b *LocalBackend) setDebugLogsByCapabilityLocked(nm *netmap.NetworkMap) { } } +// reloadServeConfigLocked reloads the serve config from the store or resets the +// serve config to nil if not logged in. The "changed" parameter, when false, instructs +// the method to only run the reset-logic and not reload the store from memory to ensure +// foreground sessions are not removed if they are not saved on disk. func (b *LocalBackend) reloadServeConfigLocked(prefs ipn.PrefsView) { - if b.netMap == nil || b.netMap.SelfNode == nil || !prefs.Valid() || b.pm.CurrentProfile().ID == "" { + if b.netMap == nil || !b.netMap.SelfNode.Valid() || !prefs.Valid() || b.pm.CurrentProfile().ID == "" { // We're not logged in, so we don't have a profile. // Don't try to load the serve config. b.lastServeConfJSON = mem.B(nil) b.serveConfig = ipn.ServeConfigView{} return } + confKey := ipn.ServeConfigKey(b.pm.CurrentProfile().ID) // TODO(maisem,bradfitz): prevent reading the config from disk // if the profile has not changed. @@ -4076,6 +4264,12 @@ func (b *LocalBackend) reloadServeConfigLocked(prefs ipn.PrefsView) { b.serveConfig = ipn.ServeConfigView{} return } + + // remove inactive sessions + maps.DeleteFunc(conf.Foreground, func(s string, sc *ipn.ServeConfig) bool { + return !b.activeWatchSessions.Contains(s) + }) + b.serveConfig = conf.View() } @@ -4093,7 +4287,7 @@ func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn. 
b.reloadServeConfigLocked(prefs) if b.serveConfig.Valid() { servePorts := make([]uint16, 0, 3) - b.serveConfig.TCP().Range(func(port uint16, _ ipn.TCPPortHandlerView) bool { + b.serveConfig.RangeOverTCPs(func(port uint16, _ ipn.TCPPortHandlerView) bool { if port > 0 { servePorts = append(servePorts, uint16(port)) } @@ -4126,7 +4320,7 @@ func (b *LocalBackend) setServeProxyHandlersLocked() { return } var backends map[string]bool - b.serveConfig.Web().Range(func(_ ipn.HostPort, conf ipn.WebServerConfigView) (cont bool) { + b.serveConfig.RangeOverWebs(func(_ ipn.HostPort, conf ipn.WebServerConfigView) (cont bool) { conf.Handlers().Range(func(_ string, h ipn.HTTPHandlerView) (cont bool) { backend := h.Proxy() if backend == "" { @@ -4299,7 +4493,7 @@ func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) { if !b.capFileSharing { return nil, errors.New("file sharing not enabled by Tailscale admin") } - for _, p := range nm.Peers { + for _, p := range b.peers { if !b.peerIsTaildropTargetLocked(p) { continue } @@ -4308,11 +4502,13 @@ func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) { continue } ret = append(ret, &apitype.FileTarget{ - Node: p, + Node: p.AsStruct(), PeerAPIURL: peerAPI, }) } - // TODO: sort a different way than the netmap already is? + slices.SortFunc(ret, func(a, b *apitype.FileTarget) int { + return cmpx.Compare(a.Node.Name, b.Node.Name) + }) return ret, nil } @@ -4321,28 +4517,23 @@ func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) { // the netmap. // // b.mu must be locked. 
-func (b *LocalBackend) peerIsTaildropTargetLocked(p *tailcfg.Node) bool { - if b.netMap == nil || p == nil { +func (b *LocalBackend) peerIsTaildropTargetLocked(p tailcfg.NodeView) bool { + if b.netMap == nil || !p.Valid() { return false } - if b.netMap.User == p.User { + if b.netMap.User() == p.User() { return true } - if len(p.Addresses) > 0 && - b.peerHasCapLocked(p.Addresses[0].Addr(), tailcfg.CapabilityFileSharingTarget) { + if p.Addresses().Len() > 0 && + b.peerHasCapLocked(p.Addresses().At(0).Addr(), tailcfg.PeerCapabilityFileSharingTarget) { // Explicitly noted in the netmap ACL caps as a target. return true } return false } -func (b *LocalBackend) peerHasCapLocked(addr netip.Addr, wantCap string) bool { - for _, hasCap := range b.peerCapsLocked(addr) { - if hasCap == wantCap { - return true - } - } - return false +func (b *LocalBackend) peerHasCapLocked(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { + return b.peerCapsLocked(addr).HasCapability(wantCap) } // SetDNS adds a DNS record for the given domain name & TXT record @@ -4394,9 +4585,9 @@ func (b *LocalBackend) registerIncomingFile(inf *incomingFile, active bool) { } } -func peerAPIPorts(peer *tailcfg.Node) (p4, p6 uint16) { - svcs := peer.Hostinfo.Services() - for i, n := 0, svcs.Len(); i < n; i++ { +func peerAPIPorts(peer tailcfg.NodeView) (p4, p6 uint16) { + svcs := peer.Hostinfo().Services() + for i := range svcs.LenIter() { s := svcs.At(i) switch s.Proto { case tailcfg.PeerAPI4: @@ -4422,13 +4613,15 @@ func peerAPIURL(ip netip.Addr, port uint16) string { // peerAPIBase returns the "http://ip:port" URL base to reach peer's peerAPI. // It returns the empty string if the peer doesn't support the peerapi // or there's no matching address family based on the netmap's own addresses. 
-func peerAPIBase(nm *netmap.NetworkMap, peer *tailcfg.Node) string { - if nm == nil || peer == nil || !peer.Hostinfo.Valid() { +func peerAPIBase(nm *netmap.NetworkMap, peer tailcfg.NodeView) string { + if nm == nil || !peer.Valid() || !peer.Hostinfo().Valid() { return "" } var have4, have6 bool - for _, a := range nm.Addresses { + addrs := nm.GetAddresses() + for i := range addrs.LenIter() { + a := addrs.At(i) if !a.IsSingleIP() { continue } @@ -4449,8 +4642,9 @@ func peerAPIBase(nm *netmap.NetworkMap, peer *tailcfg.Node) string { return "" } -func nodeIP(n *tailcfg.Node, pred func(netip.Addr) bool) netip.Addr { - for _, a := range n.Addresses { +func nodeIP(n tailcfg.NodeView, pred func(netip.Addr) bool) netip.Addr { + for i := range n.Addresses().LenIter() { + a := n.Addresses().At(i) if a.IsSingleIP() && pred(a.Addr()) { return a.Addr() } @@ -4464,7 +4658,7 @@ func (b *LocalBackend) CheckIPForwarding() error { } // TODO: let the caller pass in the ranges. - warn, err := netutil.CheckIPForwarding(tsaddr.ExitRoutes(), nil) + warn, err := netutil.CheckIPForwarding(tsaddr.ExitRoutes(), b.sys.NetMon.Get().InterfaceState()) if err != nil { return err } @@ -4555,20 +4749,20 @@ func (b *LocalBackend) SetExpirySooner(ctx context.Context, expiry time.Time) er // to exitNodeID's DoH service, if available. // // If exitNodeID is the zero valid, it returns "", false. 
-func exitNodeCanProxyDNS(nm *netmap.NetworkMap, exitNodeID tailcfg.StableNodeID) (dohURL string, ok bool) { +func exitNodeCanProxyDNS(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, exitNodeID tailcfg.StableNodeID) (dohURL string, ok bool) { if exitNodeID.IsZero() { return "", false } - for _, p := range nm.Peers { - if p.StableID == exitNodeID && peerCanProxyDNS(p) { + for _, p := range peers { + if p.StableID() == exitNodeID && peerCanProxyDNS(p) { return peerAPIBase(nm, p) + "/dns-query", true } } return "", false } -func peerCanProxyDNS(p *tailcfg.Node) bool { - if p.Cap >= 26 { +func peerCanProxyDNS(p tailcfg.NodeView) bool { + if p.Cap() >= 26 { // Actually added at 25 // (https://github.com/tailscale/tailscale/blob/3ae6f898cfdb58fd0e30937147dd6ce28c6808dd/tailcfg/tailcfg.go#L51) // so anything >= 26 can do it. @@ -4576,10 +4770,9 @@ func peerCanProxyDNS(p *tailcfg.Node) bool { } // If p.Cap is not populated (e.g. older control server), then do the old // thing of searching through services. 
- services := p.Hostinfo.Services() - for i, n := 0, services.Len(); i < n; i++ { - s := services.At(i) - if s.Proto == tailcfg.PeerAPIDNS && s.Port >= 1 { + services := p.Hostinfo().Services() + for i := range services.LenIter() { + if s := services.At(i); s.Proto == tailcfg.PeerAPIDNS && s.Port >= 1 { return true } } @@ -4587,29 +4780,22 @@ func peerCanProxyDNS(p *tailcfg.Node) bool { } func (b *LocalBackend) DebugRebind() error { - mc, err := b.magicConn() - if err != nil { - return err - } - mc.Rebind() + b.magicConn().Rebind() return nil } func (b *LocalBackend) DebugReSTUN() error { - mc, err := b.magicConn() - if err != nil { - return err - } - mc.ReSTUN("explicit-debug") + b.magicConn().ReSTUN("explicit-debug") return nil } -func (b *LocalBackend) magicConn() (*magicsock.Conn, error) { - mc, ok := b.sys.MagicSock.GetOK() - if !ok { - return nil, errors.New("failed to get magicsock from sys") - } - return mc, nil +// ControlKnobs returns the node's control knobs. +func (b *LocalBackend) ControlKnobs() *controlknobs.Knobs { + return b.sys.ControlKnobs() +} + +func (b *LocalBackend) magicConn() *magicsock.Conn { + return b.sys.MagicSock.Get() } type keyProvingNoiseRoundTripper struct { @@ -4763,13 +4949,14 @@ func (b *LocalBackend) handleQuad100Port80Conn(w http.ResponseWriter, r *http.Re io.WriteString(w, "No netmap.\n") return } - if len(b.netMap.Addresses) == 0 { + addrs := b.netMap.GetAddresses() + if addrs.Len() == 0 { io.WriteString(w, "No local addresses.\n") return } io.WriteString(w, "

Local addresses:

    \n") - for _, ipp := range b.netMap.Addresses { - fmt.Fprintf(w, "
  • %v
  • \n", ipp.Addr()) + for i := range addrs.LenIter() { + fmt.Fprintf(w, "
  • %v
  • \n", addrs.At(i).Addr()) } io.WriteString(w, "
\n") } @@ -4779,7 +4966,7 @@ func (b *LocalBackend) Doctor(ctx context.Context, logf logger.Logf) { // opting-out of rate limits. Limit ourselves to at most one message // per 20ms and a burst of 60 log lines, which should be fast enough to // not block for too long but slow enough that we can upload all lines. - logf = logger.SlowLoggerWithClock(ctx, logf, 20*time.Millisecond, 60, time.Now) + logf = logger.SlowLoggerWithClock(ctx, logf, 20*time.Millisecond, 60, b.clock.Now) var checks []doctor.Check checks = append(checks, @@ -4831,7 +5018,7 @@ func (b *LocalBackend) SetDevStateStore(key, value string) error { if b.store == nil { return errors.New("no state store") } - err := b.store.WriteState(ipn.StateKey(key), []byte(value)) + err := ipn.WriteState(b.store, ipn.StateKey(key), []byte(value)) b.logf("SetDevStateStore(%q, %q) = %v", key, value, err) if err != nil { @@ -4915,16 +5102,26 @@ func (b *LocalBackend) initTKALocked() error { } // resetForProfileChangeLockedOnEntry resets the backend for a profile change. +// +// b.mu must held on entry. It is released on exit. func (b *LocalBackend) resetForProfileChangeLockedOnEntry() error { + if b.shutdownCalled { + // Prevent a call back to Start during Shutdown, which calls Logout for + // ephemeral nodes, which can then call back here. But we're shutting + // down, so no need to do any work. + b.mu.Unlock() + return nil + } b.setNetMapLocked(nil) // Reset netmap. // Reset the NetworkMap in the engine b.e.SetNetworkMap(new(netmap.NetworkMap)) if err := b.initTKALocked(); err != nil { + b.mu.Unlock() return err } b.lastServeConfJSON = mem.B(nil) b.serveConfig = ipn.ServeConfigView{} - b.enterStateLockedOnEntry(ipn.NoState) // Reset state. + b.enterStateLockedOnEntry(ipn.NoState) // Reset state; releases b.mu health.SetLocalLogConfigHealth(nil) return b.Start(ipn.Options{}) } @@ -4975,7 +5172,10 @@ func (b *LocalBackend) ListProfiles() []ipn.LoginProfile { // called to register it as new node. 
func (b *LocalBackend) ResetAuth() error { b.mu.Lock() - b.resetControlClientLockedAsync() + prevCC := b.resetControlClientLocked() + if prevCC != nil { + defer prevCC.Shutdown() // call must happen after release b.mu + } if err := b.clearMachineKeyLocked(); err != nil { b.mu.Unlock() return err @@ -5039,14 +5239,22 @@ func (b *LocalBackend) GetPeerEndpointChanges(ctx context.Context, ip netip.Addr } peer := pip.Node - mc, err := b.magicConn() - if err != nil { - return nil, fmt.Errorf("getting magicsock conn: %w", err) - } - - chs, err := mc.GetEndpointChanges(peer) + chs, err := b.magicConn().GetEndpointChanges(peer) if err != nil { return nil, fmt.Errorf("getting endpoint changes: %w", err) } return chs, nil } + +var breakTCPConns func() error + +func (b *LocalBackend) DebugBreakTCPConns() error { + if breakTCPConns == nil { + return errors.New("TCP connection breaking not available on this platform") + } + return breakTCPConns() +} + +func (b *LocalBackend) DebugBreakDERPConns() error { + return b.magicConn().DebugBreakDERPConns() +} diff --git a/vendor/tailscale.com/ipn/ipnlocal/network-lock.go b/vendor/tailscale.com/ipn/ipnlocal/network-lock.go index b2d09b6279..11cebcca38 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/network-lock.go +++ b/vendor/tailscale.com/ipn/ipnlocal/network-lock.go @@ -20,8 +20,8 @@ import ( "path/filepath" "time" - "tailscale.com/envknob" "tailscale.com/health" + "tailscale.com/health/healthmsg" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/tsaddr" @@ -53,20 +53,12 @@ type tkaState struct { filtered []ipnstate.TKAFilteredPeer } -// permitTKAInitLocked returns true if tailnet lock initialization may -// occur. -// b.mu must be held. -func (b *LocalBackend) permitTKAInitLocked() bool { - return envknob.UseWIPCode() || b.capTailnetLock -} - // tkaFilterNetmapLocked checks the signatures on each node key, dropping // nodes from the netmap whose signature does not verify. // // b.mu must be held. 
func (b *LocalBackend) tkaFilterNetmapLocked(nm *netmap.NetworkMap) { - // TODO(tom): Remove this guard for 1.35 and later. - if b.tka == nil && !b.permitTKAInitLocked() { + if b.tka == nil && !b.capTailnetLock { health.SetTKAHealth(nil) return } @@ -77,16 +69,16 @@ func (b *LocalBackend) tkaFilterNetmapLocked(nm *netmap.NetworkMap) { var toDelete map[int]bool // peer index => true for i, p := range nm.Peers { - if p.UnsignedPeerAPIOnly { + if p.UnsignedPeerAPIOnly() { // Not subject to tailnet lock. continue } - if len(p.KeySignature) == 0 { - b.logf("Network lock is dropping peer %v(%v) due to missing signature", p.ID, p.StableID) + if p.KeySignature().Len() == 0 { + b.logf("Network lock is dropping peer %v(%v) due to missing signature", p.ID(), p.StableID()) mak.Set(&toDelete, i, true) } else { - if err := b.tka.authority.NodeKeyAuthorized(p.Key, p.KeySignature); err != nil { - b.logf("Network lock is dropping peer %v(%v) due to failed signature check: %v", p.ID, p.StableID, err) + if err := b.tka.authority.NodeKeyAuthorized(p.Key(), p.KeySignature().AsSlice()); err != nil { + b.logf("Network lock is dropping peer %v(%v) due to failed signature check: %v", p.ID(), p.StableID(), err) mak.Set(&toDelete, i, true) } } @@ -94,7 +86,7 @@ func (b *LocalBackend) tkaFilterNetmapLocked(nm *netmap.NetworkMap) { // nm.Peers is ordered, so deletion must be order-preserving. if len(toDelete) > 0 { - peers := make([]*tailcfg.Node, 0, len(nm.Peers)) + peers := make([]tailcfg.NodeView, 0, len(nm.Peers)) filtered := make([]ipnstate.TKAFilteredPeer, 0, len(toDelete)) for i, p := range nm.Peers { if !toDelete[i] { @@ -102,13 +94,14 @@ func (b *LocalBackend) tkaFilterNetmapLocked(nm *netmap.NetworkMap) { } else { // Record information about the node we filtered out. 
fp := ipnstate.TKAFilteredPeer{ - Name: p.Name, - ID: p.ID, - StableID: p.StableID, - TailscaleIPs: make([]netip.Addr, len(p.Addresses)), - NodeKey: p.Key, + Name: p.Name(), + ID: p.ID(), + StableID: p.StableID(), + TailscaleIPs: make([]netip.Addr, p.Addresses().Len()), + NodeKey: p.Key(), } - for i, addr := range p.Addresses { + for i := range p.Addresses().LenIter() { + addr := p.Addresses().At(i) if addr.IsSingleIP() && tsaddr.IsTailscaleIP(addr.Addr()) { fp.TailscaleIPs[i] = addr.Addr() } @@ -123,8 +116,8 @@ func (b *LocalBackend) tkaFilterNetmapLocked(nm *netmap.NetworkMap) { } // Check that we ourselves are not locked out, report a health issue if so. - if nm.SelfNode != nil && b.tka.authority.NodeKeyAuthorized(nm.SelfNode.Key, nm.SelfNode.KeySignature) != nil { - health.SetTKAHealth(errors.New("this node is locked out; it will not have connectivity until it is signed. For more info, see https://tailscale.com/s/locked-out")) + if nm.SelfNode.Valid() && b.tka.authority.NodeKeyAuthorized(nm.SelfNode.Key(), nm.SelfNode.KeySignature().AsSlice()) != nil { + health.SetTKAHealth(errors.New(healthmsg.LockedOut)) } else { health.SetTKAHealth(nil) } @@ -153,8 +146,7 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie b.mu.Lock() // take mu to protect access to synchronized fields. defer b.mu.Unlock() - // TODO(tom): Remove this guard for 1.35 and later. 
- if b.tka == nil && !b.permitTKAInitLocked() { + if b.tka == nil && !b.capTailnetLock { return nil } @@ -433,7 +425,7 @@ func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus { var selfAuthorized bool if b.netMap != nil { - selfAuthorized = b.tka.authority.NodeKeyAuthorized(b.netMap.SelfNode.Key, b.netMap.SelfNode.KeySignature) == nil + selfAuthorized = b.tka.authority.NodeKeyAuthorized(b.netMap.SelfNode.Key(), b.netMap.SelfNode.KeySignature().AsSlice()) == nil } keys := b.tka.authority.Keys() @@ -483,10 +475,9 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt var nlPriv key.NLPrivate b.mu.Lock() - // TODO(tom): Remove this guard for 1.35 and later. - if !b.permitTKAInitLocked() { + if !b.capTailnetLock { b.mu.Unlock() - return errors.New("this feature is not yet complete, a later release may support this functionality") + return errors.New("not permitted to enable tailnet lock") } if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() { @@ -587,7 +578,7 @@ func (b *LocalBackend) NetworkLockForceLocalDisable() error { newPrefs := b.pm.CurrentPrefs().AsStruct().Clone() // .Persist should always be initialized here. newPrefs.Persist.DisallowedTKAStateIDs = append(newPrefs.Persist.DisallowedTKAStateIDs, stateID) - if err := b.pm.SetPrefs(newPrefs.View()); err != nil { + if err := b.pm.SetPrefs(newPrefs.View(), b.netMap.MagicDNSSuffix()); err != nil { return fmt.Errorf("saving prefs: %w", err) } @@ -855,6 +846,93 @@ func (b *LocalBackend) NetworkLockAffectedSigs(keyID tkatype.KeyID) ([]tkatype.M return resp.Signatures, nil } +// NetworkLockGenerateRecoveryAUM generates an AUM which retroactively removes trust in the +// specified keys. This AUM is signed by the current node and returned. +// +// If forkFrom is specified, it is used as the parent AUM to fork from. If the zero value, +// the parent AUM is determined automatically. 
+func (b *LocalBackend) NetworkLockGenerateRecoveryAUM(removeKeys []tkatype.KeyID, forkFrom tka.AUMHash) (*tka.AUM, error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.tka == nil { + return nil, errNetworkLockNotActive + } + var nlPriv key.NLPrivate + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() { + nlPriv = p.Persist().NetworkLockKey() + } + if nlPriv.IsZero() { + return nil, errMissingNetmap + } + + aum, err := b.tka.authority.MakeRetroactiveRevocation(b.tka.storage, removeKeys, nlPriv.KeyID(), forkFrom) + if err != nil { + return nil, err + } + + // Sign it ourselves. + aum.Signatures, err = nlPriv.SignAUM(aum.SigHash()) + if err != nil { + return nil, fmt.Errorf("signing failed: %w", err) + } + + return aum, nil +} + +// NetworkLockCosignRecoveryAUM co-signs the provided recovery AUM and returns +// the updated structure. +// +// The recovery AUM provided should be the output from a previous call to +// NetworkLockGenerateRecoveryAUM or NetworkLockCosignRecoveryAUM. +func (b *LocalBackend) NetworkLockCosignRecoveryAUM(aum *tka.AUM) (*tka.AUM, error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.tka == nil { + return nil, errNetworkLockNotActive + } + var nlPriv key.NLPrivate + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() { + nlPriv = p.Persist().NetworkLockKey() + } + if nlPriv.IsZero() { + return nil, errMissingNetmap + } + for _, sig := range aum.Signatures { + if bytes.Equal(sig.KeyID, nlPriv.KeyID()) { + return nil, errors.New("this node has already signed this recovery AUM") + } + } + + // Sign it ourselves. + sigs, err := nlPriv.SignAUM(aum.SigHash()) + if err != nil { + return nil, fmt.Errorf("signing failed: %w", err) + } + aum.Signatures = append(aum.Signatures, sigs...) 
+ + return aum, nil +} + +func (b *LocalBackend) NetworkLockSubmitRecoveryAUM(aum *tka.AUM) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.tka == nil { + return errNetworkLockNotActive + } + var ourNodeKey key.NodePublic + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() { + ourNodeKey = p.Persist().PublicNodeKey() + } + if ourNodeKey.IsZero() { + return errors.New("no node-key: is tailscale logged in?") + } + + b.mu.Unlock() + _, err := b.tkaDoSyncSend(ourNodeKey, aum.Hash(), []tka.AUM{*aum}, false) + b.mu.Lock() + return err +} + var tkaSuffixEncoder = base64.RawStdEncoding // NetworkLockWrapPreauthKey wraps a pre-auth key with information to diff --git a/vendor/tailscale.com/ipn/ipnlocal/peerapi.go b/vendor/tailscale.com/ipn/ipnlocal/peerapi.go index 6d5f8c0fbc..a02ee2e400 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/peerapi.go +++ b/vendor/tailscale.com/ipn/ipnlocal/peerapi.go @@ -22,6 +22,7 @@ import ( "path" "path/filepath" "runtime" + "slices" "sort" "strconv" "strings" @@ -32,7 +33,6 @@ import ( "unicode/utf8" "github.com/kortschak/wol" - "golang.org/x/exp/slices" "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/http/httpguts" "tailscale.com/client/tailscale/apitype" @@ -47,6 +47,7 @@ import ( "tailscale.com/net/netutil" "tailscale.com/net/sockstats" "tailscale.com/tailcfg" + "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/multierr" "tailscale.com/version/distro" @@ -135,6 +136,9 @@ func (s *peerAPIServer) diskPath(baseName string) (fullPath string, ok bool) { return "", false } } + if !filepath.IsLocal(baseName) { + return "", false + } return filepath.Join(s.rootDir, baseName), true } @@ -304,7 +308,7 @@ func (s *peerAPIServer) DeleteFile(baseName string) error { } var bo *backoff.Backoff logf := s.b.logf - t0 := time.Now() + t0 := s.b.clock.Now() for { err := os.Remove(path) if err != nil && !os.IsNotExist(err) { @@ -323,7 +327,7 @@ func (s *peerAPIServer) 
DeleteFile(baseName string) error { if bo == nil { bo = backoff.NewBackoff("delete-retry", logf, 1*time.Second) } - if time.Since(t0) < 5*time.Second { + if s.b.clock.Since(t0) < 5*time.Second { bo.BackOff(context.Background(), err) continue } @@ -569,14 +573,14 @@ func (pln *peerAPIListener) ServeConn(src netip.AddrPort, c net.Conn) { return } nm := pln.lb.NetMap() - if nm == nil || nm.SelfNode == nil { + if nm == nil || !nm.SelfNode.Valid() { logf("peerapi: no netmap") c.Close() return } h := &peerAPIHandler{ ps: pln.ps, - isSelf: nm.SelfNode.User == peerNode.User, + isSelf: nm.SelfNode.User() == peerNode.User(), remoteAddr: src, selfNode: nm.SelfNode, peerNode: peerNode, @@ -596,8 +600,8 @@ type peerAPIHandler struct { ps *peerAPIServer remoteAddr netip.AddrPort isSelf bool // whether peerNode is owned by same user as this node - selfNode *tailcfg.Node // this node; always non-nil - peerNode *tailcfg.Node // peerNode is who's making the request + selfNode tailcfg.NodeView // this node; always non-nil + peerNode tailcfg.NodeView // peerNode is who's making the request peerUser tailcfg.UserProfile // profile of peerNode } @@ -608,11 +612,14 @@ func (h *peerAPIHandler) logf(format string, a ...any) { // isAddressValid reports whether addr is a valid destination address for this // node originating from the peer. 
func (h *peerAPIHandler) isAddressValid(addr netip.Addr) bool { - if h.peerNode.SelfNodeV4MasqAddrForThisPeer != nil { - return *h.peerNode.SelfNodeV4MasqAddrForThisPeer == addr + if v := h.peerNode.SelfNodeV4MasqAddrForThisPeer(); v != nil { + return *v == addr + } + if v := h.peerNode.SelfNodeV6MasqAddrForThisPeer(); v != nil { + return *v == addr } pfx := netip.PrefixFrom(addr, addr.BitLen()) - return slices.Contains(h.selfNode.Addresses, pfx) + return views.SliceContains(h.selfNode.Addresses(), pfx) } func (h *peerAPIHandler) validateHost(r *http.Request) error { @@ -733,7 +740,7 @@ func (h *peerAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

Hello, %s (%v)

This is my Tailscale device. Your device is %v. -`, html.EscapeString(who), h.remoteAddr.Addr(), html.EscapeString(h.peerNode.ComputedName)) +`, html.EscapeString(who), h.remoteAddr.Addr(), html.EscapeString(h.peerNode.ComputedName())) if h.isSelf { fmt.Fprintf(w, "

You are the owner of this node.\n") @@ -902,8 +909,8 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req for label := range stats.Stats { labels = append(labels, label) } - slices.SortFunc(labels, func(a, b sockstats.Label) bool { - return a.String() < b.String() + slices.SortFunc(labels, func(a, b sockstats.Label) int { + return strings.Compare(a.String(), b.String()) }) txTotal := uint64(0) @@ -1000,7 +1007,7 @@ func (f *incomingFile) Write(p []byte) (n int, err error) { f.mu.Lock() defer f.mu.Unlock() f.copied += int64(n) - now := time.Now() + now := b.clock.Now() if f.lastNotify.IsZero() || now.Sub(f.lastNotify) > time.Second { f.lastNotify = now needNotify = true @@ -1024,49 +1031,44 @@ func (f *incomingFile) PartialFile() ipn.PartialFile { // canPutFile reports whether h can put a file ("Taildrop") to this node. func (h *peerAPIHandler) canPutFile() bool { - if h.peerNode.UnsignedPeerAPIOnly { + if h.peerNode.UnsignedPeerAPIOnly() { // Unsigned peers can't send files. return false } - return h.isSelf || h.peerHasCap(tailcfg.CapabilityFileSharingSend) + return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityFileSharingSend) } // canDebug reports whether h can debug this node (goroutines, metrics, // magicsock internal state, etc). func (h *peerAPIHandler) canDebug() bool { - if !slices.Contains(h.selfNode.Capabilities, tailcfg.CapabilityDebug) { + if !h.selfNode.HasCap(tailcfg.CapabilityDebug) { // This node does not expose debug info. return false } - if h.peerNode.UnsignedPeerAPIOnly { + if h.peerNode.UnsignedPeerAPIOnly() { // Unsigned peers can't debug. return false } - return h.isSelf || h.peerHasCap(tailcfg.CapabilityDebugPeer) + return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityDebugPeer) } // canWakeOnLAN reports whether h can send a Wake-on-LAN packet from this node. 
func (h *peerAPIHandler) canWakeOnLAN() bool { - if h.peerNode.UnsignedPeerAPIOnly { + if h.peerNode.UnsignedPeerAPIOnly() { return false } - return h.isSelf || h.peerHasCap(tailcfg.CapabilityWakeOnLAN) + return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityWakeOnLAN) } var allowSelfIngress = envknob.RegisterBool("TS_ALLOW_SELF_INGRESS") // canIngress reports whether h can send ingress requests to this node. func (h *peerAPIHandler) canIngress() bool { - return h.peerHasCap(tailcfg.CapabilityIngress) || (allowSelfIngress() && h.isSelf) + return h.peerHasCap(tailcfg.PeerCapabilityIngress) || (allowSelfIngress() && h.isSelf) } -func (h *peerAPIHandler) peerHasCap(wantCap string) bool { - for _, hasCap := range h.ps.b.PeerCaps(h.remoteAddr.Addr()) { - if hasCap == wantCap { - return true - } - } - return false +func (h *peerAPIHandler) peerHasCap(wantCap tailcfg.PeerCapability) bool { + return h.ps.b.PeerCaps(h.remoteAddr.Addr()).HasCapability(wantCap) } func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) { @@ -1118,7 +1120,7 @@ func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad filename", 400) return } - t0 := time.Now() + t0 := h.ps.b.clock.Now() // TODO(bradfitz): prevent same filename being sent by two peers at once partialFile := dstFile + partialSuffix f, err := os.Create(partialFile) @@ -1138,7 +1140,7 @@ func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) { if r.ContentLength != 0 { inFile = &incomingFile{ name: baseName, - started: time.Now(), + started: h.ps.b.clock.Now(), size: r.ContentLength, w: f, ph: h, @@ -1176,7 +1178,7 @@ func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) { } } - d := time.Since(t0).Round(time.Second / 10) + d := h.ps.b.clock.Since(t0).Round(time.Second / 10) h.logf("got put of %s in %v from %v/%v", approxSize(finalSize), d, h.remoteAddr.Addr(), h.peerNode.ComputedName) // TODO: set modtime @@ -1238,11 
+1240,7 @@ func (h *peerAPIHandler) handleServeMagicsock(w http.ResponseWriter, r *http.Req http.Error(w, "denied; no debug access", http.StatusForbidden) return } - if mc, ok := h.ps.b.sys.MagicSock.GetOK(); ok { - mc.ServeHTTPDebug(w, r) - return - } - http.Error(w, "miswired", 500) + h.ps.b.magicConn().ServeHTTPDebug(w, r) } func (h *peerAPIHandler) handleServeMetrics(w http.ResponseWriter, r *http.Request) { @@ -1287,8 +1285,8 @@ func (h *peerAPIHandler) handleWakeOnLAN(w http.ResponseWriter, r *http.Request) return } var password []byte // TODO(bradfitz): support? - st, err := interfaces.GetState() - if err != nil { + st := h.ps.b.sys.NetMon.Get().InterfaceState() + if st == nil { http.Error(w, "failed to get interfaces state", http.StatusInternalServerError) return } diff --git a/vendor/tailscale.com/ipn/ipnlocal/profiles.go b/vendor/tailscale.com/ipn/ipnlocal/profiles.go index f831a40931..30f4c59f84 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/profiles.go +++ b/vendor/tailscale.com/ipn/ipnlocal/profiles.go @@ -10,15 +10,15 @@ import ( "math/rand" "net/netip" "runtime" + "slices" "strings" "time" - "golang.org/x/exp/slices" "tailscale.com/envknob" "tailscale.com/ipn" - "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" + "tailscale.com/util/cmpx" "tailscale.com/util/winutil" ) @@ -28,20 +28,16 @@ var debug = envknob.RegisterBool("TS_DEBUG_PROFILES") // profileManager is a wrapper around a StateStore that manages // multiple profiles and the current profile. +// +// It is not safe for concurrent use. type profileManager struct { store ipn.StateStore logf logger.Logf currentUserID ipn.WindowsUserID - knownProfiles map[ipn.ProfileID]*ipn.LoginProfile - currentProfile *ipn.LoginProfile // always non-nil - prefs ipn.PrefsView // always Valid. - - // isNewProfile is a sentinel value that indicates that the - // current profile is new and has not been saved to disk yet. 
- // It is reset to false after a call to SetPrefs with a filled - // in LoginName. - isNewProfile bool + knownProfiles map[ipn.ProfileID]*ipn.LoginProfile // always non-nil + currentProfile *ipn.LoginProfile // always non-nil + prefs ipn.PrefsView // always Valid. } func (pm *profileManager) dlogf(format string, args ...any) { @@ -51,6 +47,10 @@ func (pm *profileManager) dlogf(format string, args ...any) { pm.logf(format, args...) } +func (pm *profileManager) WriteState(id ipn.StateKey, val []byte) error { + return ipn.WriteState(pm.store, id, val) +} + // CurrentUserID returns the current user ID. It is only non-empty on // Windows where we have a multi-user system. func (pm *profileManager) CurrentUserID() ipn.WindowsUserID { @@ -103,40 +103,45 @@ func (pm *profileManager) SetCurrentUserID(uid ipn.WindowsUserID) error { } pm.currentProfile = prof pm.prefs = prefs - pm.isNewProfile = false return nil } -// matchingProfiles returns all profiles that match the given predicate and -// belong to the currentUserID. -func (pm *profileManager) matchingProfiles(f func(*ipn.LoginProfile) bool) (out []*ipn.LoginProfile) { +// allProfiles returns all profiles that belong to the currentUserID. +// The returned profiles are sorted by Name. +func (pm *profileManager) allProfiles() (out []*ipn.LoginProfile) { for _, p := range pm.knownProfiles { - if p.LocalUserID == pm.currentUserID && f(p) { + if p.LocalUserID == pm.currentUserID { out = append(out, p) } } + slices.SortFunc(out, func(a, b *ipn.LoginProfile) int { + return cmpx.Compare(a.Name, b.Name) + }) return out } -// findProfilesByNodeID returns all profiles that have the provided nodeID and -// belong to the same control server. -func (pm *profileManager) findProfilesByNodeID(controlURL string, nodeID tailcfg.StableNodeID) []*ipn.LoginProfile { - if nodeID.IsZero() { - return nil +// matchingProfiles returns all profiles that match the given predicate and +// belong to the currentUserID. 
+// The returned profiles are sorted by Name. +func (pm *profileManager) matchingProfiles(f func(*ipn.LoginProfile) bool) (out []*ipn.LoginProfile) { + all := pm.allProfiles() + out = all[:0] + for _, p := range all { + if f(p) { + out = append(out, p) + } } - return pm.matchingProfiles(func(p *ipn.LoginProfile) bool { - return p.NodeID == nodeID && p.ControlURL == controlURL - }) + return out } -// findProfilesByUserID returns all profiles that have the provided userID and -// belong to the same control server. -func (pm *profileManager) findProfilesByUserID(controlURL string, userID tailcfg.UserID) []*ipn.LoginProfile { - if userID.IsZero() { - return nil - } +// findMatchinProfiles returns all profiles that represent the same node/user as +// prefs. +// The returned profiles are sorted by Name. +func (pm *profileManager) findMatchingProfiles(prefs *ipn.Prefs) []*ipn.LoginProfile { return pm.matchingProfiles(func(p *ipn.LoginProfile) bool { - return p.UserProfile.ID == userID && p.ControlURL == controlURL + return p.ControlURL == prefs.ControlURL && + (p.UserProfile.ID == prefs.Persist.UserProfile.ID || + p.NodeID == prefs.Persist.NodeID) }) } @@ -182,9 +187,9 @@ func (pm *profileManager) setUnattendedModeAsConfigured() error { } if pm.prefs.ForceDaemon() { - return pm.store.WriteState(ipn.ServerModeStartKey, []byte(pm.currentProfile.Key)) + return pm.WriteState(ipn.ServerModeStartKey, []byte(pm.currentProfile.Key)) } else { - return pm.store.WriteState(ipn.ServerModeStartKey, nil) + return pm.WriteState(ipn.ServerModeStartKey, nil) } } @@ -201,49 +206,58 @@ func init() { // SetPrefs sets the current profile's prefs to the provided value. // It also saves the prefs to the StateStore. It stores a copy of the // provided prefs, which may be accessed via CurrentPrefs. 
-func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView) error { - prefs := prefsIn.AsStruct().View() - newPersist := prefs.Persist().AsStruct() - if newPersist == nil || newPersist.NodeID == "" { - return pm.setPrefsLocked(prefs) +// +// If tailnetMagicDNSName is provided non-empty, it will be used to +// enrich the profile with the tailnet's MagicDNS name. The MagicDNS +// name cannot be pulled from prefsIn directly because it is not saved +// on ipn.Prefs (since it's not a field that is configurable by nodes). +func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView, tailnetMagicDNSName string) error { + prefs := prefsIn.AsStruct() + newPersist := prefs.Persist + if newPersist == nil || newPersist.NodeID == "" || newPersist.UserProfile.LoginName == "" { + // We don't know anything about this profile, so ignore it for now. + return pm.setPrefsLocked(prefs.View()) } up := newPersist.UserProfile - if up.LoginName == "" { - // Backwards compatibility with old prefs files. - up.LoginName = newPersist.LoginName - } else { - newPersist.LoginName = up.LoginName - } if up.DisplayName == "" { up.DisplayName = up.LoginName } cp := pm.currentProfile - if pm.isNewProfile { - pm.isNewProfile = false - // Check if we already have a profile for this user. - existing := pm.findProfilesByUserID(prefs.ControlURL(), newPersist.UserProfile.ID) - // Also check if we have a profile with the same NodeID. - existing = append(existing, pm.findProfilesByNodeID(prefs.ControlURL(), newPersist.NodeID)...) - if len(existing) == 0 { - cp.ID, cp.Key = newUnusedID(pm.knownProfiles) - } else { - // Only one profile per user/nodeID should exist. - for _, p := range existing[1:] { - // Best effort cleanup. - pm.DeleteProfile(p.ID) + // Check if we already have an existing profile that matches the user/node. + if existing := pm.findMatchingProfiles(prefs); len(existing) > 0 { + // We already have a profile for this user/node we should reuse it. Also + // cleanup any other duplicate profiles. 
+ cp = existing[0] + existing = existing[1:] + for _, p := range existing { + // Clear the state. + if err := pm.store.WriteState(p.Key, nil); err != nil { + // We couldn't delete the state, so keep the profile around. + continue } - cp = existing[0] + // Remove the profile, knownProfiles will be persisted below. + delete(pm.knownProfiles, p.ID) } + } else if cp.ID == "" { + // We didn't have an existing profile, so create a new one. + cp.ID, cp.Key = newUnusedID(pm.knownProfiles) cp.LocalUserID = pm.currentUserID + } else { + // This means that there was a force-reauth as a new node that + // we haven't seen before. } - if prefs.ProfileName() != "" { - cp.Name = prefs.ProfileName() + + if prefs.ProfileName != "" { + cp.Name = prefs.ProfileName } else { cp.Name = up.LoginName } - cp.ControlURL = prefs.ControlURL() + cp.ControlURL = prefs.ControlURL cp.UserProfile = newPersist.UserProfile cp.NodeID = newPersist.NodeID + if tailnetMagicDNSName != "" { + cp.TailnetMagicDNSName = tailnetMagicDNSName + } pm.knownProfiles[cp.ID] = cp pm.currentProfile = cp if err := pm.writeKnownProfiles(); err != nil { @@ -252,7 +266,7 @@ func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView) error { if err := pm.setAsUserSelectedProfileLocked(); err != nil { return err } - if err := pm.setPrefsLocked(prefs); err != nil { + if err := pm.setPrefsLocked(prefs.View()); err != nil { return err } return nil @@ -275,7 +289,7 @@ func newUnusedID(knownProfiles map[ipn.ProfileID]*ipn.LoginProfile) (ipn.Profile // is not new. 
func (pm *profileManager) setPrefsLocked(clonedPrefs ipn.PrefsView) error { pm.prefs = clonedPrefs - if pm.isNewProfile { + if pm.currentProfile.ID == "" { return nil } if err := pm.writePrefsToStore(pm.currentProfile.Key, pm.prefs); err != nil { @@ -288,7 +302,7 @@ func (pm *profileManager) writePrefsToStore(key ipn.StateKey, prefs ipn.PrefsVie if key == "" { return nil } - if err := pm.store.WriteState(key, prefs.ToBytes()); err != nil { + if err := pm.WriteState(key, prefs.ToBytes()); err != nil { pm.logf("WriteState(%q): %v", key, err) return err } @@ -297,12 +311,9 @@ func (pm *profileManager) writePrefsToStore(key ipn.StateKey, prefs ipn.PrefsVie // Profiles returns the list of known profiles. func (pm *profileManager) Profiles() []ipn.LoginProfile { - profiles := pm.matchingProfiles(func(*ipn.LoginProfile) bool { return true }) - slices.SortFunc(profiles, func(a, b *ipn.LoginProfile) bool { - return a.Name < b.Name - }) - out := make([]ipn.LoginProfile, 0, len(profiles)) - for _, p := range profiles { + allProfiles := pm.allProfiles() + out := make([]ipn.LoginProfile, 0, len(allProfiles)) + for _, p := range allProfiles { out = append(out, *p) } return out @@ -330,13 +341,12 @@ func (pm *profileManager) SwitchProfile(id ipn.ProfileID) error { } pm.prefs = prefs pm.currentProfile = kp - pm.isNewProfile = false return pm.setAsUserSelectedProfileLocked() } func (pm *profileManager) setAsUserSelectedProfileLocked() error { k := ipn.CurrentProfileKey(string(pm.currentUserID)) - return pm.store.WriteState(k, []byte(pm.currentProfile.Key)) + return pm.WriteState(k, []byte(pm.currentProfile.Key)) } func (pm *profileManager) loadSavedPrefs(key ipn.StateKey) (ipn.PrefsView, error) { @@ -382,7 +392,7 @@ var errProfileNotFound = errors.New("profile not found") func (pm *profileManager) DeleteProfile(id ipn.ProfileID) error { metricDeleteProfile.Add(1) - if id == "" && pm.isNewProfile { + if id == "" { // Deleting the in-memory only new profile, just create a new one. 
pm.NewProfile() return nil @@ -394,7 +404,7 @@ func (pm *profileManager) DeleteProfile(id ipn.ProfileID) error { if kp.ID == pm.currentProfile.ID { pm.NewProfile() } - if err := pm.store.WriteState(kp.Key, nil); err != nil { + if err := pm.WriteState(kp.Key, nil); err != nil { return err } delete(pm.knownProfiles, id) @@ -407,7 +417,7 @@ func (pm *profileManager) DeleteAllProfiles() error { metricDeleteAllProfile.Add(1) for _, kp := range pm.knownProfiles { - if err := pm.store.WriteState(kp.Key, nil); err != nil { + if err := pm.WriteState(kp.Key, nil); err != nil { // Write to remove references to profiles we've already deleted, but // return the original error. pm.writeKnownProfiles() @@ -424,7 +434,7 @@ func (pm *profileManager) writeKnownProfiles() error { if err != nil { return err } - return pm.store.WriteState(ipn.KnownProfilesStateKey, b) + return pm.WriteState(ipn.KnownProfilesStateKey, b) } // NewProfile creates and switches to a new unnamed profile. The new profile is @@ -433,13 +443,13 @@ func (pm *profileManager) NewProfile() { metricNewProfile.Add(1) pm.prefs = defaultPrefs - pm.isNewProfile = true pm.currentProfile = &ipn.LoginProfile{} } // defaultPrefs is the default prefs for a new profile. 
var defaultPrefs = func() ipn.PrefsView { prefs := ipn.NewPrefs() + prefs.LoggedOut = true prefs.WantRunning = false prefs.ControlURL = winutil.GetPolicyString("LoginURL", "") @@ -587,7 +597,7 @@ func (pm *profileManager) migrateFromLegacyPrefs() error { return fmt.Errorf("load legacy prefs: %w", err) } pm.dlogf("loaded legacy preferences; sentinel=%q", sentinel) - if err := pm.SetPrefs(prefs); err != nil { + if err := pm.SetPrefs(prefs, ""); err != nil { metricMigrationError.Add(1) return fmt.Errorf("migrating _daemon profile: %w", err) } diff --git a/vendor/tailscale.com/ipn/ipnlocal/profiles_notwindows.go b/vendor/tailscale.com/ipn/ipnlocal/profiles_notwindows.go index 5e045d5f2c..fc61d26713 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/profiles_notwindows.go +++ b/vendor/tailscale.com/ipn/ipnlocal/profiles_notwindows.go @@ -16,9 +16,7 @@ import ( func (pm *profileManager) loadLegacyPrefs() (string, ipn.PrefsView, error) { k := ipn.LegacyGlobalDaemonStateKey switch { - case runtime.GOOS == "ios": - k = "ipn-go-bridge" - case version.IsSandboxedMacOS(): + case runtime.GOOS == "ios", version.IsSandboxedMacOS(): k = "ipn-go-bridge" case runtime.GOOS == "android": k = "ipn-android" diff --git a/vendor/tailscale.com/ipn/ipnlocal/serve.go b/vendor/tailscale.com/ipn/ipnlocal/serve.go index a6a7d1421f..c28ca02390 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/serve.go +++ b/vendor/tailscale.com/ipn/ipnlocal/serve.go @@ -5,7 +5,9 @@ package ipnlocal import ( "context" + "crypto/sha256" "crypto/tls" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -17,12 +19,12 @@ import ( "net/url" "os" "path" + "slices" "strconv" "strings" "sync" "time" - "golang.org/x/exp/slices" "tailscale.com/ipn" "tailscale.com/logtail/backoff" "tailscale.com/net/netutil" @@ -33,6 +35,11 @@ import ( "tailscale.com/version" ) +// ErrETagMismatch signals that the given +// If-Match header does not match with the +// current etag of a resource. 
+var ErrETagMismatch = errors.New("etag mismatch") + // serveHTTPContextKey is the context.Value key for a *serveHTTPContext. type serveHTTPContextKey struct{} @@ -193,12 +200,14 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 b.logf("netMap is nil") return } - if nm.SelfNode == nil { + if !nm.SelfNode.Valid() { b.logf("netMap SelfNode is nil") return } - for _, a := range nm.Addresses { + addrs := nm.GetAddresses() + for i := range addrs.LenIter() { + a := addrs.At(i) for _, p := range ports { addrPort := netip.AddrPortFrom(a.Addr(), p) if _, ok := b.serveListeners[addrPort]; ok { @@ -214,10 +223,15 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 } // SetServeConfig establishes or replaces the current serve config. -func (b *LocalBackend) SetServeConfig(config *ipn.ServeConfig) error { +// ETag is an optional parameter to enforce Optimistic Concurrency Control. +// If it is an empty string, then the config will be overwritten. +func (b *LocalBackend) SetServeConfig(config *ipn.ServeConfig, etag string) error { b.mu.Lock() defer b.mu.Unlock() + return b.setServeConfigLocked(config, etag) +} +func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string) error { prefs := b.pm.CurrentPrefs() if config.IsFunnelOn() && prefs.ShieldsUp() { return errors.New("Unable to turn on Funnel while shields-up is enabled") @@ -227,11 +241,27 @@ func (b *LocalBackend) SetServeConfig(config *ipn.ServeConfig) error { if nm == nil { return errors.New("netMap is nil") } - if nm.SelfNode == nil { + if !nm.SelfNode.Valid() { return errors.New("netMap SelfNode is nil") } - profileID := b.pm.CurrentProfile().ID - confKey := ipn.ServeConfigKey(profileID) + + // If etag is present, check that it has + // not changed from the last config. 
+ if etag != "" { + // Note that we marshal b.serveConfig + // and not use b.lastServeConfJSON as that might + // be a Go nil value, which produces a different + // checksum from a JSON "null" value. + previousCfg, err := json.Marshal(b.serveConfig) + if err != nil { + return fmt.Errorf("error encoding previous config: %w", err) + } + sum := sha256.Sum256(previousCfg) + previousEtag := hex.EncodeToString(sum[:]) + if etag != previousEtag { + return ErrETagMismatch + } + } var bs []byte if config != nil { @@ -241,6 +271,9 @@ func (b *LocalBackend) SetServeConfig(config *ipn.ServeConfig) error { } bs = j } + + profileID := b.pm.CurrentProfile().ID + confKey := ipn.ServeConfigKey(profileID) if err := b.store.WriteState(confKey, bs); err != nil { return fmt.Errorf("writing ServeConfig to StateStore: %w", err) } @@ -257,7 +290,21 @@ func (b *LocalBackend) ServeConfig() ipn.ServeConfigView { return b.serveConfig } -func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ipn.HostPort, srcAddr netip.AddrPort, getConnOrReset func() (net.Conn, bool), sendRST func()) { +// DeleteForegroundSession deletes a ServeConfig's foreground session +// in the LocalBackend if it exists. It also ensures check, delete, and +// set operations happen within the same mutex lock to avoid any races. 
+func (b *LocalBackend) DeleteForegroundSession(sessionID string) error { + b.mu.Lock() + defer b.mu.Unlock() + if !b.serveConfig.Valid() || !b.serveConfig.Foreground().Has(sessionID) { + return nil + } + sc := b.serveConfig.AsStruct() + delete(sc.Foreground, sessionID) + return b.setServeConfigLocked(sc, "") +} + +func (b *LocalBackend) HandleIngressTCPConn(ingressPeer tailcfg.NodeView, target ipn.HostPort, srcAddr netip.AddrPort, getConnOrReset func() (net.Conn, bool), sendRST func()) { b.mu.Lock() sc := b.serveConfig b.mu.Unlock() @@ -268,7 +315,7 @@ func (b *LocalBackend) HandleIngressTCPConn(ingressPeer *tailcfg.Node, target ip return } - if !sc.AllowFunnel().Get(target) { + if !sc.HasFunnelForTarget(target) { b.logf("localbackend: got ingress conn for unconfigured %q; rejecting", target) sendRST() return @@ -326,7 +373,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) return nil } - tcph, ok := sc.TCP().GetOk(dport) + tcph, ok := sc.FindTCP(dport) if !ok { b.logf("[unexpected] localbackend: got TCP conn without TCP config for port %v; from %v", dport, srcAddr) return nil @@ -372,7 +419,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, sni) + pair, err := b.GetCertPEM(ctx, sni, false) if err != nil { return nil, err } @@ -415,6 +462,9 @@ func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, hostname := r.Host if r.TLS == nil { tcd := "." + b.Status().CurrentTailnet.MagicDNSSuffix + if host, _, err := net.SplitHostPort(hostname); err == nil { + hostname = host + } if !strings.HasSuffix(hostname, tcd) { hostname += tcd } @@ -496,6 +546,7 @@ func (b *LocalBackend) addTailscaleIdentityHeaders(r *httputil.ProxyRequest) { // Clear any incoming values squatting in the headers. 
r.Out.Header.Del("Tailscale-User-Login") r.Out.Header.Del("Tailscale-User-Name") + r.Out.Header.Del("Tailscale-User-Profile-Pic") r.Out.Header.Del("Tailscale-Headers-Info") c, ok := getServeHTTPContext(r.Out) @@ -513,9 +564,12 @@ func (b *LocalBackend) addTailscaleIdentityHeaders(r *httputil.ProxyRequest) { } r.Out.Header.Set("Tailscale-User-Login", user.LoginName) r.Out.Header.Set("Tailscale-User-Name", user.DisplayName) + r.Out.Header.Set("Tailscale-User-Profile-Pic", user.ProfilePicURL) r.Out.Header.Set("Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers") } +// serveWebHandler is an http.HandlerFunc that maps incoming requests to the +// correct *http. func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { h, mountPoint, ok := b.getServeHandler(r) if !ok { @@ -657,7 +711,7 @@ func (b *LocalBackend) webServerConfig(hostname string, port uint16) (c ipn.WebS if !b.serveConfig.Valid() { return c, false } - return b.serveConfig.Web().GetOk(key) + return b.serveConfig.FindWeb(key) } func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -672,7 +726,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, hi.ServerName) + pair, err := b.GetCertPEM(ctx, hi.ServerName, false) if err != nil { return nil, err } diff --git a/vendor/tailscale.com/ipn/ipnlocal/ssh.go b/vendor/tailscale.com/ipn/ipnlocal/ssh.go index 6fa56d7d9f..19a23c030a 100644 --- a/vendor/tailscale.com/ipn/ipnlocal/ssh.go +++ b/vendor/tailscale.com/ipn/ipnlocal/ssh.go @@ -20,12 +20,12 @@ import ( "os/exec" "path/filepath" "runtime" + "slices" "strings" "sync" "github.com/tailscale/golang-x-crypto/ssh" "go4.org/mem" - "golang.org/x/exp/slices" "tailscale.com/tailcfg" "tailscale.com/util/lineread" "tailscale.com/util/mak" diff --git 
a/vendor/tailscale.com/ipn/ipnstate/ipnstate.go b/vendor/tailscale.com/ipn/ipnstate/ipnstate.go index 3a464c9be1..03b41d6a91 100644 --- a/vendor/tailscale.com/ipn/ipnstate/ipnstate.go +++ b/vendor/tailscale.com/ipn/ipnstate/ipnstate.go @@ -12,9 +12,9 @@ import ( "io" "log" "net/netip" + "slices" "sort" "strings" - "sync" "time" "tailscale.com/tailcfg" @@ -69,8 +69,17 @@ type Status struct { // trailing periods, and without any "_acme-challenge." prefix. CertDomains []string + // Peer is the state of each peer, keyed by each peer's current public key. Peer map[key.NodePublic]*PeerStatus + + // User contains profile information about UserIDs referenced by + // PeerStatus.UserID, PeerStatus.AltSharerUserID, etc. User map[tailcfg.UserID]tailcfg.UserProfile + + // ClientVersion, when non-nil, contains information about the latest + // version of the Tailscale client that's available. Depending on + // the platform and client settings, it may not be available. + ClientVersion *tailcfg.ClientVersion } // TKAKey describes a key trusted by network lock. @@ -188,6 +197,7 @@ type PeerStatusLite struct { NodeKey key.NodePublic } +// PeerStatus describes a peer node and its current state. type PeerStatus struct { ID tailcfg.StableNodeID PublicKey key.NodePublic @@ -199,6 +209,10 @@ type PeerStatus struct { OS string // HostInfo.OS UserID tailcfg.UserID + // AltSharerUserID is the user who shared this node + // if it's different than UserID. Otherwise it's zero. + AltSharerUserID tailcfg.UserID `json:",omitempty"` + // TailscaleIPs are the IP addresses assigned to the node. TailscaleIPs []netip.Addr @@ -209,7 +223,7 @@ type PeerStatus struct { // PrimaryRoutes are the routes this node is currently the primary // subnet router for, as determined by the control plane. It does // not include the IPs in TailscaleIPs. 
- PrimaryRoutes *views.IPPrefixSlice `json:",omitempty"` + PrimaryRoutes *views.Slice[netip.Prefix] `json:",omitempty"` // Endpoints: Addrs []string @@ -223,9 +237,8 @@ type PeerStatus struct { LastSeen time.Time // last seen to tailcontrol; only present if offline LastHandshake time.Time // with local wireguard Online bool // whether node is connected to the control plane - KeepAlive bool - ExitNode bool // true if this is the currently selected exit node. - ExitNodeOption bool // true if this node can be an exit node (offered && approved) + ExitNode bool // true if this is the currently selected exit node. + ExitNodeOption bool // true if this node can be an exit node (offered && approved) // Active is whether the node was recently active. The // definition is somewhat undefined but has historically and @@ -243,7 +256,10 @@ type PeerStatus struct { // "https://tailscale.com/cap/is-admin" // "https://tailscale.com/cap/file-sharing" // "funnel" - Capabilities []string `json:",omitempty"` + Capabilities []tailcfg.NodeCapability `json:",omitempty"` + + // CapMap is a map of capabilities to their values. + CapMap tailcfg.NodeCapMap `json:",omitempty"` // SSH_HostKeys are the node's SSH host keys, if known. SSH_HostKeys []string `json:"sshHostKeys,omitempty"` @@ -274,12 +290,21 @@ type PeerStatus struct { // KeyExpiry, if present, is the time at which the node key expired or // will expire. KeyExpiry *time.Time `json:",omitempty"` + + Location *tailcfg.Location `json:",omitempty"` +} + +// HasCap reports whether ps has the given capability. +func (ps *PeerStatus) HasCap(cap tailcfg.NodeCapability) bool { + return ps.CapMap.Contains(cap) || slices.Contains(ps.Capabilities, cap) } +// StatusBuilder is a request to construct a Status. A new StatusBuilder is +// passed to various subsystems which then call methods on it to populate state. +// Call its Status method to return the final constructed Status. 
type StatusBuilder struct { WantPeers bool // whether caller wants peers - mu sync.Mutex locked bool st Status } @@ -288,17 +313,13 @@ type StatusBuilder struct { // // It may not assume other fields of status are already populated, and // may not retain or write to the Status after f returns. -// -// MutateStatus acquires a lock so f must not call back into sb. func (sb *StatusBuilder) MutateStatus(f func(*Status)) { - sb.mu.Lock() - defer sb.mu.Unlock() f(&sb.st) } +// Status returns the status that has been built up so far from previous +// calls to MutateStatus, MutateSelfStatus, AddPeer, etc. func (sb *StatusBuilder) Status() *Status { - sb.mu.Lock() - defer sb.mu.Unlock() sb.locked = true return &sb.st } @@ -310,8 +331,6 @@ func (sb *StatusBuilder) Status() *Status { // // MutateStatus acquires a lock so f must not call back into sb. func (sb *StatusBuilder) MutateSelfStatus(f func(*PeerStatus)) { - sb.mu.Lock() - defer sb.mu.Unlock() if sb.st.Self == nil { sb.st.Self = new(PeerStatus) } @@ -320,8 +339,6 @@ func (sb *StatusBuilder) MutateSelfStatus(f func(*PeerStatus)) { // AddUser adds a user profile to the status. func (sb *StatusBuilder) AddUser(id tailcfg.UserID, up tailcfg.UserProfile) { - sb.mu.Lock() - defer sb.mu.Unlock() if sb.locked { log.Printf("[unexpected] ipnstate: AddUser after Locked") return @@ -336,8 +353,6 @@ func (sb *StatusBuilder) AddUser(id tailcfg.UserID, up tailcfg.UserProfile) { // AddIP adds a Tailscale IP address to the status. 
func (sb *StatusBuilder) AddTailscaleIP(ip netip.Addr) { - sb.mu.Lock() - defer sb.mu.Unlock() if sb.locked { log.Printf("[unexpected] ipnstate: AddIP after Locked") return @@ -354,8 +369,6 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { panic("nil PeerStatus") } - sb.mu.Lock() - defer sb.mu.Unlock() if sb.locked { log.Printf("[unexpected] ipnstate: AddPeer after Locked") return @@ -386,6 +399,9 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { if v := st.UserID; v != 0 { e.UserID = v } + if v := st.AltSharerUserID; v != 0 { + e.AltSharerUserID = v + } if v := st.TailscaleIPs; v != nil { e.TailscaleIPs = v } @@ -437,9 +453,6 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { if st.InEngine { e.InEngine = true } - if st.KeepAlive { - e.KeepAlive = true - } if st.ExitNode { e.ExitNode = true } @@ -461,6 +474,7 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { if t := st.KeyExpiry; t != nil { e.KeyExpiry = ptr.To(*t) } + e.Location = st.Location } type StatusUpdater interface { @@ -659,23 +673,29 @@ func (pr *PingResult) ToPingResponse(pingType tailcfg.PingType) *tailcfg.PingRes } } +// SortPeers sorts peers by either their DNS name, hostname, Tailscale IP, +// or ultimately their current public key. func SortPeers(peers []*PeerStatus) { - sort.Slice(peers, func(i, j int) bool { return sortKey(peers[i]) < sortKey(peers[j]) }) + slices.SortStableFunc(peers, (*PeerStatus).compare) } -func sortKey(ps *PeerStatus) string { - if ps.DNSName != "" { - return ps.DNSName +func (a *PeerStatus) compare(b *PeerStatus) int { + if a.DNSName != "" || b.DNSName != "" { + if v := strings.Compare(a.DNSName, b.DNSName); v != 0 { + return v + } } - if ps.HostName != "" { - return ps.HostName + if a.HostName != "" || b.HostName != "" { + if v := strings.Compare(a.HostName, b.HostName); v != 0 { + return v + } } - // TODO(bradfitz): add PeerStatus.Less and avoid these allocs in a Less func. 
- if len(ps.TailscaleIPs) > 0 { - return ps.TailscaleIPs[0].String() + if len(a.TailscaleIPs) > 0 && len(b.TailscaleIPs) > 0 { + if v := a.TailscaleIPs[0].Compare(b.TailscaleIPs[0]); v != 0 { + return v + } } - raw := ps.PublicKey.Raw32() - return string(raw[:]) + return a.PublicKey.Compare(b.PublicKey) } // DebugDERPRegionReport is the result of a "tailscale debug derp" command, diff --git a/vendor/tailscale.com/ipn/localapi/cert.go b/vendor/tailscale.com/ipn/localapi/cert.go index 447c3bc3cb..e1704cb499 100644 --- a/vendor/tailscale.com/ipn/localapi/cert.go +++ b/vendor/tailscale.com/ipn/localapi/cert.go @@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal handler config wired wrong", 500) return } - pair, err := h.b.GetCertPEM(r.Context(), domain) + pair, err := h.b.GetCertPEM(r.Context(), domain, true) if err != nil { // TODO(bradfitz): 500 is a little lazy here. The errors returned from // GetCertPEM (and everywhere) should carry info info to get whether diff --git a/vendor/tailscale.com/ipn/localapi/localapi.go b/vendor/tailscale.com/ipn/localapi/localapi.go index d99dcad883..ec5b1cd793 100644 --- a/vendor/tailscale.com/ipn/localapi/localapi.go +++ b/vendor/tailscale.com/ipn/localapi/localapi.go @@ -7,25 +7,24 @@ package localapi import ( "bytes" "context" - "crypto/rand" + "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/httputil" "net/netip" "net/url" "runtime" + "slices" "strconv" "strings" "sync" "time" - "golang.org/x/exp/slices" "tailscale.com/client/tailscale/apitype" "tailscale.com/envknob" "tailscale.com/health" @@ -37,15 +36,20 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/netutil" "tailscale.com/net/portmapper" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tka" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/ptr" + 
"tailscale.com/types/tkatype" "tailscale.com/util/clientmetric" "tailscale.com/util/httpm" "tailscale.com/util/mak" + "tailscale.com/util/osdiag" + "tailscale.com/util/rands" "tailscale.com/version" ) @@ -105,15 +109,13 @@ var handler = map[string]localAPIHandler{ "tka/affected-sigs": (*Handler).serveTKAAffectedSigs, "tka/wrap-preauth-key": (*Handler).serveTKAWrapPreauthKey, "tka/verify-deeplink": (*Handler).serveTKAVerifySigningDeeplink, + "tka/generate-recovery-aum": (*Handler).serveTKAGenerateRecoveryAUM, + "tka/cosign-recovery-aum": (*Handler).serveTKACosignRecoveryAUM, + "tka/submit-recovery-aum": (*Handler).serveTKASubmitRecoveryAUM, "upload-client-metrics": (*Handler).serveUploadClientMetrics, "watch-ipn-bus": (*Handler).serveWatchIPNBus, "whois": (*Handler).serveWhoIs, -} - -func randHex(n int) string { - b := make([]byte, n) - rand.Read(b) - return hex.EncodeToString(b) + "query-feature": (*Handler).serveQueryFeature, } var ( @@ -129,7 +131,7 @@ var ( // NewHandler creates a new LocalAPI HTTP handler. All parameters except netMon // are required (if non-nil it's used to do faster interface lookups). 
func NewHandler(b *ipnlocal.LocalBackend, logf logger.Logf, netMon *netmon.Monitor, logID logid.PublicID) *Handler { - return &Handler{b: b, logf: logf, netMon: netMon, backendLogID: logID} + return &Handler{b: b, logf: logf, netMon: netMon, backendLogID: logID, clock: tstime.StdClock{}} } type Handler struct { @@ -155,6 +157,7 @@ type Handler struct { logf logger.Logf netMon *netmon.Monitor // optional; nil means interfaces will be looked up on-demand backendLogID logid.PublicID + clock tstime.Clock } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -309,7 +312,7 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { defer h.b.TryFlushLogs() // kick off upload after bugreport's done logging logMarker := func() string { - return fmt.Sprintf("BUG-%v-%v-%v", h.backendLogID, time.Now().UTC().Format("20060102150405Z"), randHex(8)) + return fmt.Sprintf("BUG-%v-%v-%v", h.backendLogID, h.clock.Now().UTC().Format("20060102150405Z"), rands.HexString(16)) } if envknob.NoLogsNoSupport() { logMarker = func() string { return "BUG-NO-LOGS-NO-SUPPORT-this-node-has-had-its-logging-disabled" } @@ -330,8 +333,8 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { // Information about the current node from the netmap if nm := h.b.NetMap(); nm != nil { - if self := nm.SelfNode; self != nil { - h.logf("user bugreport node info: nodeid=%q stableid=%q expiry=%q", self.ID, self.StableID, self.KeyExpiry.Format(time.RFC3339)) + if self := nm.SelfNode; self.Valid() { + h.logf("user bugreport node info: nodeid=%q stableid=%q expiry=%q", self.ID(), self.StableID(), self.KeyExpiry().Format(time.RFC3339)) } h.logf("user bugreport public keys: machine=%q node=%q", nm.MachineKey, nm.NodeKey) } else { @@ -343,6 +346,9 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { // logs for them. 
envknob.LogCurrent(logger.WithPrefix(h.logf, "user bugreport: ")) + // OS-specific details + osdiag.LogSupportInfo(logger.WithPrefix(h.logf, "user bugreport OS: "), osdiag.LogSupportInfoReasonBugReport) + if defBool(r.URL.Query().Get("diagnose"), false) { h.b.Doctor(r.Context(), logger.WithPrefix(h.logf, "diag: ")) } @@ -355,7 +361,7 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { return } - until := time.Now().Add(12 * time.Hour) + until := h.clock.Now().Add(12 * time.Hour) var changed map[string]bool for _, component := range []string{"magicsock"} { @@ -425,9 +431,9 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) { return } res := &apitype.WhoIsResponse{ - Node: n, - UserProfile: &u, - Caps: b.PeerCaps(ipp.Addr()), + Node: n.AsStruct(), // always non-nil per WhoIsResponse contract + UserProfile: &u, // always non-nil per WhoIsResponse contract + CapMap: b.PeerCaps(ipp.Addr()), } j, err := json.MarshalIndent(res, "", "\t") if err != nil { @@ -547,7 +553,19 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { break } h.b.DebugNotify(n) - + case "break-tcp-conns": + err = h.b.DebugBreakTCPConns() + case "break-derp-conns": + err = h.b.DebugBreakDERPConns() + case "force-netmap-update": + h.b.DebugForceNetmapUpdate() + case "control-knobs": + k := h.b.ControlKnobs() + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(k.AsDebugJSON()) + if err == nil { + return + } case "": err = fmt.Errorf("missing parameter 'action'") default: @@ -645,6 +663,10 @@ func (h *Handler) serveDebugPortmap(w http.ResponseWriter, r *http.Request) { return } + if defBool(r.FormValue("log_http"), false) { + debugKnobs.LogHTTP = true + } + var ( logLock sync.Mutex handlerDone bool @@ -683,7 +705,7 @@ func (h *Handler) serveDebugPortmap(w http.ResponseWriter, r *http.Request) { done := make(chan bool, 1) var c *portmapper.Client - c = portmapper.NewClient(logger.WithPrefix(logf, 
"portmapper: "), h.netMon, debugKnobs, func() { + c = portmapper.NewClient(logger.WithPrefix(logf, "portmapper: "), h.netMon, debugKnobs, h.b.ControlKnobs(), func() { logf("portmapping changed.") logf("have mapping: %v", c.HaveMapping()) @@ -766,7 +788,7 @@ func (h *Handler) serveComponentDebugLogging(w http.ResponseWriter, r *http.Requ } component := r.FormValue("component") secs, _ := strconv.Atoi(r.FormValue("secs")) - err := h.b.SetComponentDebugLogging(component, time.Now().Add(time.Duration(secs)*time.Second)) + err := h.b.SetComponentDebugLogging(component, h.clock.Now().Add(time.Duration(secs)*time.Second)) var res struct { Error string } @@ -819,9 +841,17 @@ func (h *Handler) serveServeConfig(w http.ResponseWriter, r *http.Request) { http.Error(w, "serve config denied", http.StatusForbidden) return } - w.Header().Set("Content-Type", "application/json") config := h.b.ServeConfig() - json.NewEncoder(w).Encode(config) + bts, err := json.Marshal(config) + if err != nil { + http.Error(w, "error encoding config: "+err.Error(), http.StatusInternalServerError) + return + } + sum := sha256.Sum256(bts) + etag := hex.EncodeToString(sum[:]) + w.Header().Set("Etag", etag) + w.Header().Set("Content-Type", "application/json") + w.Write(bts) case "POST": if !h.PermitWrite { http.Error(w, "serve config denied", http.StatusForbidden) @@ -832,7 +862,12 @@ func (h *Handler) serveServeConfig(w http.ResponseWriter, r *http.Request) { writeErrorJSON(w, fmt.Errorf("decoding config: %w", err)) return } - if err := h.b.SetServeConfig(configIn); err != nil { + etag := r.Header.Get("If-Match") + if err := h.b.SetServeConfig(configIn, etag); err != nil { + if errors.Is(err, ipnlocal.ErrETagMismatch) { + http.Error(w, err.Error(), http.StatusPreconditionFailed) + return + } writeErrorJSON(w, fmt.Errorf("updating config: %w", err)) return } @@ -1012,7 +1047,7 @@ func (h *Handler) serveLogout(w http.ResponseWriter, r *http.Request) { http.Error(w, "want POST", 400) return } - err := 
h.b.LogoutSync(r.Context()) + err := h.b.Logout(r.Context()) if err == nil { w.WriteHeader(http.StatusNoContent) return @@ -1331,11 +1366,28 @@ func (h *Handler) servePing(w http.ResponseWriter, r *http.Request) { return } pingTypeStr := r.FormValue("type") - if ipStr == "" { + if pingTypeStr == "" { http.Error(w, "missing 'type' parameter", 400) return } - res, err := h.b.Ping(ctx, ip, tailcfg.PingType(pingTypeStr)) + size := 0 + sizeStr := r.FormValue("size") + if sizeStr != "" { + size, err = strconv.Atoi(sizeStr) + if err != nil { + http.Error(w, "invalid 'size' parameter", 400) + return + } + if size != 0 && tailcfg.PingType(pingTypeStr) != tailcfg.PingDisco { + http.Error(w, "'size' parameter is only supported with disco pings", 400) + return + } + if size > int(tstun.DefaultMTU()) { + http.Error(w, fmt.Sprintf("maximum value for 'size' is %v", tstun.DefaultMTU()), 400) + return + } + } + res, err := h.b.Ping(ctx, ip, tailcfg.PingType(pingTypeStr), size) if err != nil { writeErrorJSON(w, err) return @@ -1426,10 +1478,9 @@ func (h *Handler) serveUploadClientMetrics(w http.ResponseWriter, r *http.Reques return } type clientMetricJSON struct { - Name string `json:"name"` - // One of "counter" or "gauge" - Type string `json:"type"` - Value int `json:"value"` + Name string `json:"name"` + Type string `json:"type"` // one of "counter" or "gauge" + Value int `json:"value"` // amount to increment metric by } var clientMetrics []clientMetricJSON @@ -1651,7 +1702,7 @@ func (h *Handler) serveTKADisable(w http.ResponseWriter, r *http.Request) { } body := io.LimitReader(r.Body, 1024*1024) - secret, err := ioutil.ReadAll(body) + secret, err := io.ReadAll(body) if err != nil { http.Error(w, "reading secret", 400) return @@ -1724,7 +1775,7 @@ func (h *Handler) serveTKAAffectedSigs(w http.ResponseWriter, r *http.Request) { http.Error(w, "use POST", http.StatusMethodNotAllowed) return } - keyID, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, 2048)) + keyID, err := 
io.ReadAll(http.MaxBytesReader(w, r.Body, 2048)) if err != nil { http.Error(w, "reading body", http.StatusBadRequest) return @@ -1745,6 +1796,103 @@ func (h *Handler) serveTKAAffectedSigs(w http.ResponseWriter, r *http.Request) { w.Write(j) } +func (h *Handler) serveTKAGenerateRecoveryAUM(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + type verifyRequest struct { + Keys []tkatype.KeyID + ForkFrom string + } + var req verifyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON for verifyRequest body", http.StatusBadRequest) + return + } + + var forkFrom tka.AUMHash + if req.ForkFrom != "" { + if err := forkFrom.UnmarshalText([]byte(req.ForkFrom)); err != nil { + http.Error(w, "decoding fork-from: "+err.Error(), http.StatusBadRequest) + return + } + } + + res, err := h.b.NetworkLockGenerateRecoveryAUM(req.Keys, forkFrom) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(res.Serialize()) +} + +func (h *Handler) serveTKACosignRecoveryAUM(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + body := io.LimitReader(r.Body, 1024*1024) + aumBytes, err := io.ReadAll(body) + if err != nil { + http.Error(w, "reading AUM", http.StatusBadRequest) + return + } + var aum tka.AUM + if err := aum.Unserialize(aumBytes); err != nil { + http.Error(w, "decoding AUM", http.StatusBadRequest) + return + } + + res, err := h.b.NetworkLockCosignRecoveryAUM(&aum) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", 
"application/octet-stream") + w.Write(res.Serialize()) +} + +func (h *Handler) serveTKASubmitRecoveryAUM(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + body := io.LimitReader(r.Body, 1024*1024) + aumBytes, err := io.ReadAll(body) + if err != nil { + http.Error(w, "reading AUM", http.StatusBadRequest) + return + } + var aum tka.AUM + if err := aum.Unserialize(aumBytes); err != nil { + http.Error(w, "decoding AUM", http.StatusBadRequest) + return + } + + if err := h.b.NetworkLockSubmitRecoveryAUM(&aum); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + // serveProfiles serves profile switching-related endpoints. Supported methods // and paths are: // - GET /profiles/: list all profiles (JSON-encoded array of ipn.LoginProfiles) @@ -1829,6 +1977,66 @@ func (h *Handler) serveProfiles(w http.ResponseWriter, r *http.Request) { } } +// serveQueryFeature makes a request to the "/machine/feature/query" +// Noise endpoint to get instructions on how to enable a feature, such as +// Funnel, for the node's tailnet. +// +// This request itself does not directly enable the feature on behalf of +// the node, but rather returns information that can be presented to the +// acting user about where/how to enable the feature. If relevant, this +// includes a control URL the user can visit to explicitly consent to +// using the feature. +// +// See tailcfg.QueryFeatureResponse for full response structure. 
+func (h *Handler) serveQueryFeature(w http.ResponseWriter, r *http.Request) { + feature := r.FormValue("feature") + switch { + case !h.PermitRead: + http.Error(w, "access denied", http.StatusForbidden) + return + case r.Method != httpm.POST: + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + case feature == "": + http.Error(w, "missing feature", http.StatusInternalServerError) + return + } + nm := h.b.NetMap() + if nm == nil { + http.Error(w, "no netmap", http.StatusServiceUnavailable) + return + } + + b, err := json.Marshal(&tailcfg.QueryFeatureRequest{ + NodeKey: nm.NodeKey, + Feature: feature, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + req, err := http.NewRequestWithContext(r.Context(), + "POST", "https://unused/machine/feature/query", bytes.NewReader(b)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + resp, err := h.b.DoNoiseRequest(req) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + func defBool(a string, def bool) bool { if a == "" { return def @@ -1887,7 +2095,7 @@ func (h *Handler) serveDebugLog(w http.ResponseWriter, r *http.Request) { // opting-out of rate limits. Limit ourselves to at most one message // per 20ms and a burst of 60 log lines, which should be fast enough to // not block for too long but slow enough that we can upload all lines. 
- logf = logger.SlowLoggerWithClock(r.Context(), logf, 20*time.Millisecond, 60, time.Now) + logf = logger.SlowLoggerWithClock(r.Context(), logf, 20*time.Millisecond, 60, h.clock.Now) for _, line := range logRequest.Lines { logf("%s", line) diff --git a/vendor/tailscale.com/ipn/prefs.go b/vendor/tailscale.com/ipn/prefs.go index e79da0d093..356359533e 100644 --- a/vendor/tailscale.com/ipn/prefs.go +++ b/vendor/tailscale.com/ipn/prefs.go @@ -23,6 +23,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/persist" "tailscale.com/types/preftype" + "tailscale.com/types/views" "tailscale.com/util/dnsname" ) @@ -195,6 +196,10 @@ type Prefs struct { // and CLI. ProfileName string `json:",omitempty"` + // AutoUpdate sets the auto-update preferences for the node agent. See + // AutoUpdatePrefs docs for more details. + AutoUpdate AutoUpdatePrefs + // The Persist field is named 'Config' in the file for backward // compatibility with earlier versions. // TODO(apenwarr): We should move this out of here, it's not a pref. @@ -203,6 +208,18 @@ type Prefs struct { Persist *persist.Persist `json:"Config"` } +// AutoUpdatePrefs are the auto update settings for the node agent. +type AutoUpdatePrefs struct { + // Check specifies whether background checks for updates are enabled. When + // enabled, tailscaled will periodically check for available updates and + // notify the user about them. + Check bool + // Apply specifies whether background auto-updates are enabled. When + // enabled, tailscaled will apply available updates in the background. + // Check must also be set when Apply is set. + Apply bool +} + // MaskedPrefs is a Prefs with an associated bitmask of which fields are set. 
type MaskedPrefs struct { Prefs @@ -228,6 +245,7 @@ type MaskedPrefs struct { NetfilterModeSet bool `json:",omitempty"` OperatorUserSet bool `json:",omitempty"` ProfileNameSet bool `json:",omitempty"` + AutoUpdateSet bool `json:",omitempty"` } // ApplyEdits mutates p, assigning fields from m.Prefs for each MaskedPrefs @@ -283,6 +301,12 @@ func (m *MaskedPrefs) Pretty() string { if v.Type().Elem().Kind() == reflect.String { return "%s=%q" } + case reflect.Struct: + return "%s=%+v" + case reflect.Pointer: + if v.Type().Elem().Kind() == reflect.Struct { + return "%s=%+v" + } } return "%s=%v" } @@ -359,6 +383,7 @@ func (p *Prefs) pretty(goos string) string { if p.OperatorUser != "" { fmt.Fprintf(&sb, "op=%q ", p.OperatorUser) } + sb.WriteString(p.AutoUpdate.Pretty()) if p.Persist != nil { sb.WriteString(p.Persist.Pretty()) } else { @@ -413,7 +438,18 @@ func (p *Prefs) Equals(p2 *Prefs) bool { compareIPNets(p.AdvertiseRoutes, p2.AdvertiseRoutes) && compareStrings(p.AdvertiseTags, p2.AdvertiseTags) && p.Persist.Equals(p2.Persist) && - p.ProfileName == p2.ProfileName + p.ProfileName == p2.ProfileName && + p.AutoUpdate == p2.AutoUpdate +} + +func (au AutoUpdatePrefs) Pretty() string { + if au.Apply { + return "update=on " + } + if au.Check { + return "update=check " + } + return "update=off " } func compareIPNets(a, b []netip.Prefix) bool { @@ -458,6 +494,10 @@ func NewPrefs() *Prefs { CorpDNS: true, WantRunning: false, NetfilterMode: preftype.NetfilterOn, + AutoUpdate: AutoUpdatePrefs{ + Check: true, + Apply: false, + }, } } @@ -506,7 +546,7 @@ func (p *Prefs) AdvertisesExitNode() bool { if p == nil { return false } - return tsaddr.ContainsExitRoutes(p.AdvertiseRoutes) + return tsaddr.ContainsExitRoutes(views.SliceOf(p.AdvertiseRoutes)) } // SetAdvertiseExitNode mutates p (if non-nil) to add or remove the two @@ -651,18 +691,11 @@ func PrefsFromBytes(b []byte) (*Prefs, error) { if len(b) == 0 { return p, nil } - persist := &persist.Persist{} - err := json.Unmarshal(b, 
persist) - if err == nil && (persist.Provider != "" || persist.LoginName != "") { - // old-style relaynode config; import it - p.Persist = persist - } else { - err = json.Unmarshal(b, &p) - if err != nil { - log.Printf("Prefs parse: %v: %v\n", err, b) - } + + if err := json.Unmarshal(b, p); err != nil { + return nil, err } - return p, err + return p, nil } var jsonEscapedZero = []byte(`\u0000`) @@ -722,6 +755,14 @@ type LoginProfile struct { // It is filled in from the UserProfile.LoginName field. Name string + // TailnetMagicDNSName is filled with the MagicDNS suffix for this + // profile's node (even if MagicDNS isn't necessarily in use). + // It will neither start nor end with a period. + // + // TailnetMagicDNSName is only filled from 2023-09-09 forward, + // and will only get backfilled when a profile is the current profile. + TailnetMagicDNSName string + // Key is the StateKey under which the profile is stored. // It is assigned once at profile creation time and never changes. Key StateKey diff --git a/vendor/tailscale.com/ipn/serve.go b/vendor/tailscale.com/ipn/serve.go index 48e3343a1f..b22a5bdb77 100644 --- a/vendor/tailscale.com/ipn/serve.go +++ b/vendor/tailscale.com/ipn/serve.go @@ -12,7 +12,7 @@ import ( "strconv" "strings" - "golang.org/x/exp/slices" + "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" ) @@ -36,12 +36,41 @@ type ServeConfig struct { // AllowFunnel is the set of SNI:port values for which funnel // traffic is allowed, from trusted ingress peers. AllowFunnel map[HostPort]bool `json:",omitempty"` + + // Foreground is a map of an IPN Bus session ID to an alternate foreground + // serve config that's valid for the life of that WatchIPNBus session ID. + // This. This allows the config to specify ephemeral configs that are + // used in the CLI's foreground mode to ensure ungraceful shutdowns + // of either the client or the LocalBackend does not expose ports + // that users are not aware of. 
+ Foreground map[string]*ServeConfig `json:",omitempty"` + + // ETag is the checksum of the serve config that's populated + // by the LocalClient through the HTTP ETag header during a + // GetServeConfig request and is translated to an If-Match header + // during a SetServeConfig request. + ETag string `json:"-"` } // HostPort is an SNI name and port number, joined by a colon. // There is no implicit port 443. It must contain a colon. type HostPort string +// Port extracts just the port number from hp. +// An error is reported in the case that the hp does not +// have a valid numeric port ending. +func (hp HostPort) Port() (uint16, error) { + _, port, err := net.SplitHostPort(string(hp)) + if err != nil { + return 0, err + } + port16, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return 0, err + } + return uint16(port16), nil +} + // A FunnelConn wraps a net.Conn that is coming over a // Funnel connection. It can be used to determine further // information about the connection, like the source address @@ -204,31 +233,25 @@ func (sc *ServeConfig) IsFunnelOn() bool { // CheckFunnelAccess checks whether Funnel access is allowed for the given node // and port. // It checks: -// 1. Funnel is enabled on the Tailnet -// 2. HTTPS is enabled on the Tailnet -// 3. the node has the "funnel" nodeAttr -// 4. the port is allowed for Funnel +// 1. HTTPS is enabled on the Tailnet +// 2. the node has the "funnel" nodeAttr +// 3. the port is allowed for Funnel // -// The nodeAttrs arg should be the node's Self.Capabilities which should contain -// the attribute we're checking for and possibly warning-capabilities for -// Funnel. 
-func CheckFunnelAccess(port uint16, nodeAttrs []string) error { - if slices.Contains(nodeAttrs, tailcfg.CapabilityWarnFunnelNoInvite) { - return errors.New("Funnel not enabled; See https://tailscale.com/s/no-funnel.") - } - if slices.Contains(nodeAttrs, tailcfg.CapabilityWarnFunnelNoHTTPS) { +// The node arg should be the ipnstate.Status.Self node. +func CheckFunnelAccess(port uint16, node *ipnstate.PeerStatus) error { + if !node.HasCap(tailcfg.CapabilityHTTPS) { return errors.New("Funnel not available; HTTPS must be enabled. See https://tailscale.com/s/https.") } - if !slices.Contains(nodeAttrs, tailcfg.NodeAttrFunnel) { + if !node.HasCap(tailcfg.NodeAttrFunnel) { return errors.New("Funnel not available; \"funnel\" node attribute not set. See https://tailscale.com/s/no-funnel.") } - return checkFunnelPort(port, nodeAttrs) + return CheckFunnelPort(port, node) } -// checkFunnelPort checks whether the given port is allowed for Funnel. +// CheckFunnelPort checks whether the given port is allowed for Funnel. // It uses the tailcfg.CapabilityFunnelPorts nodeAttr to determine the allowed // ports. 
-func checkFunnelPort(wantedPort uint16, nodeAttrs []string) error { +func CheckFunnelPort(wantedPort uint16, node *ipnstate.PeerStatus) error { deny := func(allowedPorts string) error { if allowedPorts == "" { return fmt.Errorf("port %d is not allowed for funnel", wantedPort) @@ -236,22 +259,49 @@ func checkFunnelPort(wantedPort uint16, nodeAttrs []string) error { return fmt.Errorf("port %d is not allowed for funnel; allowed ports are: %v", wantedPort, allowedPorts) } var portsStr string - for _, attr := range nodeAttrs { - if !strings.HasPrefix(attr, tailcfg.CapabilityFunnelPorts) { - continue - } + parseAttr := func(attr string) (string, error) { u, err := url.Parse(attr) if err != nil { - return deny("") + return "", deny("") } - portsStr = u.Query().Get("ports") + portsStr := u.Query().Get("ports") if portsStr == "" { - return deny("") + return "", deny("") } u.RawQuery = "" - if u.String() != tailcfg.CapabilityFunnelPorts { - return deny("") + if u.String() != string(tailcfg.CapabilityFunnelPorts) { + return "", deny("") + } + return portsStr, nil + } + for attr := range node.CapMap { + attr := string(attr) + if !strings.HasPrefix(attr, string(tailcfg.CapabilityFunnelPorts)) { + continue } + var err error + portsStr, err = parseAttr(attr) + if err != nil { + return err + } + break + } + if portsStr == "" { + for _, attr := range node.Capabilities { + attr := string(attr) + if !strings.HasPrefix(attr, string(tailcfg.CapabilityFunnelPorts)) { + continue + } + var err error + portsStr, err = parseAttr(attr) + if err != nil { + return err + } + break + } + } + if portsStr == "" { + return deny("") } wantedPortString := strconv.Itoa(int(wantedPort)) for _, ps := range strings.Split(portsStr, ",") { @@ -280,3 +330,102 @@ func checkFunnelPort(wantedPort uint16, nodeAttrs []string) error { } return deny(portsStr) } + +// RangeOverTCPs ranges over both background and foreground TCPs. 
+// If the returned bool from the given f is false, then this function stops +// iterating immediately and does not check other foreground configs. +func (v ServeConfigView) RangeOverTCPs(f func(port uint16, _ TCPPortHandlerView) bool) { + parentCont := true + v.TCP().Range(func(k uint16, v TCPPortHandlerView) (cont bool) { + parentCont = f(k, v) + return parentCont + }) + v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { + if !parentCont { + return false + } + v.TCP().Range(func(k uint16, v TCPPortHandlerView) (cont bool) { + parentCont = f(k, v) + return parentCont + }) + return parentCont + }) +} + +// RangeOverWebs ranges over both background and foreground Webs. +// If the returned bool from the given f is false, then this function stops +// iterating immediately and does not check other foreground configs. +func (v ServeConfigView) RangeOverWebs(f func(_ HostPort, conf WebServerConfigView) bool) { + parentCont := true + v.Web().Range(func(k HostPort, v WebServerConfigView) (cont bool) { + parentCont = f(k, v) + return parentCont + }) + v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { + if !parentCont { + return false + } + v.Web().Range(func(k HostPort, v WebServerConfigView) (cont bool) { + parentCont = f(k, v) + return parentCont + }) + return parentCont + }) +} + +// FindTCP returns the first TCP that matches with the given port. It +// prefers a foreground match first followed by a background search if none +// existed. +func (v ServeConfigView) FindTCP(port uint16) (res TCPPortHandlerView, ok bool) { + v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { + res, ok = v.TCP().GetOk(port) + return !ok + }) + if ok { + return res, ok + } + return v.TCP().GetOk(port) +} + +// FindWeb returns the first Web that matches with the given HostPort. It +// prefers a foreground match first followed by a background search if none +// existed. 
+func (v ServeConfigView) FindWeb(hp HostPort) (res WebServerConfigView, ok bool) { + v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { + res, ok = v.Web().GetOk(hp) + return !ok + }) + if ok { + return res, ok + } + return v.Web().GetOk(hp) +} + +// HasAllowFunnel returns whether this config has at least one AllowFunnel +// set in the background or foreground configs. +func (v ServeConfigView) HasAllowFunnel() bool { + return v.AllowFunnel().Len() > 0 || func() bool { + var exists bool + v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { + exists = v.AllowFunnel().Len() > 0 + return !exists + }) + return exists + }() +} + +// FindFunnel reports whether target exists in in either the background AllowFunnel +// or any of the foreground configs. +func (v ServeConfigView) HasFunnelForTarget(target HostPort) bool { + if v.AllowFunnel().Get(target) { + return true + } + var exists bool + v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { + if exists = v.AllowFunnel().Get(target); exists { + return false + } + return true + }) + return exists +} diff --git a/vendor/tailscale.com/ipn/store.go b/vendor/tailscale.com/ipn/store.go index ee5a238a7e..3bef012bab 100644 --- a/vendor/tailscale.com/ipn/store.go +++ b/vendor/tailscale.com/ipn/store.go @@ -4,6 +4,7 @@ package ipn import ( + "bytes" "context" "errors" "fmt" @@ -71,9 +72,22 @@ type StateStore interface { // ErrStateNotExist) if the ID doesn't have associated state. ReadState(id StateKey) ([]byte, error) // WriteState saves bs as the state associated with ID. + // + // Callers should generally use the ipn.WriteState wrapper func + // instead, which only writes if the value is different from what's + // already in the store. WriteState(id StateKey, bs []byte) error } +// WriteState is a wrapper around store.WriteState that only writes if +// the value is different from what's already in the store. 
+func WriteState(store StateStore, id StateKey, v []byte) error { + if was, err := store.ReadState(id); err == nil && bytes.Equal(was, v) { + return nil + } + return store.WriteState(id, v) +} + // StateStoreDialerSetter is an optional interface that StateStores // can implement to allow the caller to set a custom dialer. type StateStoreDialerSetter interface { @@ -91,5 +105,5 @@ func ReadStoreInt(store StateStore, id StateKey) (int64, error) { // PutStoreInt puts an integer into a StateStore. func PutStoreInt(store StateStore, id StateKey, val int64) error { - return store.WriteState(id, fmt.Appendf(nil, "%d", val)) + return WriteState(store, id, fmt.Appendf(nil, "%d", val)) } diff --git a/vendor/tailscale.com/ipn/store/kubestore/store_kube.go b/vendor/tailscale.com/ipn/store/kubestore/store_kube.go index 65c73c1bd7..3eb29898e8 100644 --- a/vendor/tailscale.com/ipn/store/kubestore/store_kube.go +++ b/vendor/tailscale.com/ipn/store/kubestore/store_kube.go @@ -19,6 +19,7 @@ import ( // Store is an ipn.StateStore that uses a Kubernetes Secret for persistence. 
type Store struct { client *kube.Client + canPatch bool secretName string } @@ -28,8 +29,13 @@ func New(_ logger.Logf, secretName string) (*Store, error) { if err != nil { return nil, err } + canPatch, err := c.CheckSecretPermissions(context.Background(), secretName) + if err != nil { + return nil, err + } return &Store{ client: c, + canPatch: canPatch, secretName: secretName, }, nil } @@ -93,6 +99,19 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { } return err } + if s.canPatch { + m := []kube.JSONPatch{ + { + Op: "add", + Path: "/data/" + sanitizeKey(id), + Value: bs, + }, + } + if err := s.client.JSONPatchSecret(ctx, s.secretName, m); err != nil { + return err + } + return nil + } secret.Data[sanitizeKey(id)] = bs if err := s.client.UpdateSecret(ctx, secret); err != nil { return err diff --git a/vendor/tailscale.com/kube/client.go b/vendor/tailscale.com/kube/client.go index ef62b74cea..f4befd1c84 100644 --- a/vendor/tailscale.com/kube/client.go +++ b/vendor/tailscale.com/kube/client.go @@ -233,15 +233,16 @@ func (c *Client) UpdateSecret(ctx context.Context, s *Secret) error { // // https://tools.ietf.org/html/rfc6902 type JSONPatch struct { - Op string `json:"op"` - Path string `json:"path"` + Op string `json:"op"` + Path string `json:"path"` + Value any `json:"value,omitempty"` } // JSONPatchSecret updates a secret in the Kubernetes API using a JSON patch. // It currently (2023-03-02) only supports the "remove" operation. 
func (c *Client) JSONPatchSecret(ctx context.Context, name string, patch []JSONPatch) error { for _, p := range patch { - if p.Op != "remove" { + if p.Op != "remove" && p.Op != "add" { panic(fmt.Errorf("unsupported JSON patch operation: %q", p.Op)) } } diff --git a/vendor/tailscale.com/log/logheap/logheap.go b/vendor/tailscale.com/log/logheap/logheap.go deleted file mode 100644 index f645ec4c97..0000000000 --- a/vendor/tailscale.com/log/logheap/logheap.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js - -// Package logheap logs a heap pprof profile. -package logheap - -import ( - "bytes" - "context" - "log" - "net/http" - "runtime" - "runtime/pprof" - "time" -) - -// LogHeap uploads a JSON logtail record with the base64 heap pprof by means -// of an HTTP POST request to the endpoint referred to in postURL. -func LogHeap(postURL string) { - if postURL == "" { - return - } - runtime.GC() - buf := new(bytes.Buffer) - if err := pprof.WriteHeapProfile(buf); err != nil { - log.Printf("LogHeap: %v", err) - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - req, err := http.NewRequestWithContext(ctx, "POST", postURL, buf) - if err != nil { - log.Printf("LogHeap: %v", err) - return - } - res, err := http.DefaultClient.Do(req) - if err != nil { - log.Printf("LogHeap: %v", err) - return - } - defer res.Body.Close() -} diff --git a/vendor/tailscale.com/log/logheap/logheap_js.go b/vendor/tailscale.com/log/logheap/logheap_js.go deleted file mode 100644 index 35453b482c..0000000000 --- a/vendor/tailscale.com/log/logheap/logheap_js.go +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logheap - -func LogHeap(postURL string) { -} diff --git a/vendor/tailscale.com/log/sockstatlog/logger.go b/vendor/tailscale.com/log/sockstatlog/logger.go index 4f522f17d3..c1f96e8cce 
100644 --- a/vendor/tailscale.com/log/sockstatlog/logger.go +++ b/vendor/tailscale.com/log/sockstatlog/logger.go @@ -114,7 +114,7 @@ func NewLogger(logdir string, logf logger.Logf, logID logid.PublicID, netMon *ne logger := &Logger{ logf: logf, filch: filch, - tr: logpolicy.NewLogtailTransport(logtail.DefaultHost, netMon), + tr: logpolicy.NewLogtailTransport(logtail.DefaultHost, netMon, logf), } logger.logger = logtail.NewLogger(logtail.Config{ BaseURL: logpolicy.LogURL(), diff --git a/vendor/tailscale.com/logpolicy/logpolicy.go b/vendor/tailscale.com/logpolicy/logpolicy.go index 7becc934fd..c11aaf3bce 100644 --- a/vendor/tailscale.com/logpolicy/logpolicy.go +++ b/vendor/tailscale.com/logpolicy/logpolicy.go @@ -13,7 +13,6 @@ import ( "crypto/tls" "encoding/json" "errors" - "flag" "fmt" "io" "log" @@ -49,13 +48,12 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/must" "tailscale.com/util/racebuild" + "tailscale.com/util/testenv" "tailscale.com/util/winutil" "tailscale.com/version" "tailscale.com/version/distro" ) -func inTest() bool { return flag.Lookup("test.v") != nil } - var getLogTargetOnce struct { sync.Once v string // URL of logs server, or empty for default @@ -110,6 +108,8 @@ type Policy struct { Logtail *logtail.Logger // PublicID is the logger's instance identifier. PublicID logid.PublicID + // Logf is where to write informational messages about this Logger. + Logf logger.Logf } // NewConfig creates a Config with collection and a newly generated PrivateID. @@ -310,7 +310,7 @@ func winProgramDataAccessible(dir string) bool { // log state for that command exists in dir, then the log state is // moved from wherever it does exist, into dir. Leftover logs state // in / and $CACHE_DIRECTORY is deleted. 
-func tryFixLogStateLocation(dir, cmdname string) { +func tryFixLogStateLocation(dir, cmdname string, logf logger.Logf) { switch runtime.GOOS { case "linux", "freebsd", "openbsd": // These are the OSes where we might have written stuff into @@ -320,13 +320,13 @@ func tryFixLogStateLocation(dir, cmdname string) { return } if cmdname == "" { - log.Printf("[unexpected] no cmdname given to tryFixLogStateLocation, please file a bug at https://github.com/tailscale/tailscale") + logf("[unexpected] no cmdname given to tryFixLogStateLocation, please file a bug at https://github.com/tailscale/tailscale") return } if dir == "/" { // Trying to store things in / still. That's a bug, but don't // abort hard. - log.Printf("[unexpected] storing logging config in /, please file a bug at https://github.com/tailscale/tailscale") + logf("[unexpected] storing logging config in /, please file a bug at https://github.com/tailscale/tailscale") return } if os.Getuid() != 0 { @@ -383,7 +383,7 @@ func tryFixLogStateLocation(dir, cmdname string) { existsInRoot, err := checkExists("/") if err != nil { - log.Printf("checking for configs in /: %v", err) + logf("checking for configs in /: %v", err) return } existsInCache := false @@ -391,12 +391,12 @@ func tryFixLogStateLocation(dir, cmdname string) { if cacheDir != "" { existsInCache, err = checkExists("/var/cache/tailscale") if err != nil { - log.Printf("checking for configs in %s: %v", cacheDir, err) + logf("checking for configs in %s: %v", cacheDir, err) } } existsInDest, err := checkExists(dir) if err != nil { - log.Printf("checking for configs in %s: %v", dir, err) + logf("checking for configs in %s: %v", dir, err) return } @@ -411,13 +411,13 @@ func tryFixLogStateLocation(dir, cmdname string) { // CACHE_DIRECTORY takes precedence over /, move files from // there. if err := moveFiles(cacheDir); err != nil { - log.Print(err) + logf("%v", err) return } case existsInRoot: // Files from root is better than nothing. 
if err := moveFiles("/"); err != nil { - log.Print(err) + logf("%v", err) return } } @@ -439,27 +439,32 @@ func tryFixLogStateLocation(dir, cmdname string) { if os.IsNotExist(err) { continue } else if err != nil { - log.Printf("stat %q: %v", p, err) + logf("stat %q: %v", p, err) return } if err := os.Remove(p); err != nil { - log.Printf("rm %q: %v", p, err) + logf("rm %q: %v", p, err) } } } } -// New returns a new log policy (a logger and its instance ID) for a -// given collection name. -// The netMon parameter is optional; if non-nil it's used to do faster interface lookups. -func New(collection string, netMon *netmon.Monitor) *Policy { - return NewWithConfigPath(collection, "", "", netMon) +// New returns a new log policy (a logger and its instance ID) for a given +// collection name. +// +// The netMon parameter is optional; if non-nil it's used to do faster +// interface lookups. +// +// The logf parameter is optional; if non-nil, information logs (e.g. when +// migrating state) are sent to that logger, and global changes to the log +// package are avoided. If nil, logs will be printed using log.Printf. +func New(collection string, netMon *netmon.Monitor, logf logger.Logf) *Policy { + return NewWithConfigPath(collection, "", "", netMon, logf) } -// NewWithConfigPath is identical to New, -// but uses the specified directory and command name. -// If either is empty, it derives them automatically. -func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor) *Policy { +// NewWithConfigPath is identical to New, but uses the specified directory and +// command name. If either is empty, it derives them automatically. 
+func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, logf logger.Logf) *Policy { var lflags int if term.IsTerminal(2) || runtime.GOOS == "windows" { lflags = 0 @@ -488,7 +493,12 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor) if cmdName == "" { cmdName = version.CmdName() } - tryFixLogStateLocation(dir, cmdName) + + useStdLogger := logf == nil + if useStdLogger { + logf = log.Printf + } + tryFixLogStateLocation(dir, cmdName, logf) cfgPath := filepath.Join(dir, fmt.Sprintf("%s.log.conf", cmdName)) @@ -556,7 +566,7 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor) } return w }, - HTTPC: &http.Client{Transport: NewLogtailTransport(logtail.DefaultHost, netMon)}, + HTTPC: &http.Client{Transport: NewLogtailTransport(logtail.DefaultHost, netMon, logf)}, } if collection == logtail.CollectionNode { conf.MetricsDelta = clientmetric.EncodeLogTailMetricsDelta @@ -564,14 +574,14 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor) conf.IncludeProcSequence = true } - if envknob.NoLogsNoSupport() || inTest() { - log.Println("You have disabled logging. Tailscale will not be able to provide support.") + if envknob.NoLogsNoSupport() || testenv.InTest() { + logf("You have disabled logging. Tailscale will not be able to provide support.") conf.HTTPC = &http.Client{Transport: noopPretendSuccessTransport{}} } else if val := getLogTarget(); val != "" { - log.Println("You have enabled a non-default log target. Doing without being told to by Tailscale staff or your network administrator will make getting support difficult.") + logf("You have enabled a non-default log target. 
Doing without being told to by Tailscale staff or your network administrator will make getting support difficult.") conf.BaseURL = val u, _ := url.Parse(val) - conf.HTTPC = &http.Client{Transport: NewLogtailTransport(u.Host, netMon)} + conf.HTTPC = &http.Client{Transport: NewLogtailTransport(u.Host, netMon, logf)} } filchOptions := filch.Options{ @@ -588,7 +598,7 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor) filchOptions.MaxFileSize = 1 << 20 } else { // not a fatal error, we can leave the log files on the spinning disk - log.Printf("Unable to create /tmp directory for log storage: %v\n", err) + logf("Unable to create /tmp directory for log storage: %v\n", err) } } @@ -599,7 +609,7 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor) conf.Stderr = filchBuf.OrigStderr } } - lw := logtail.NewLogger(conf, log.Printf) + lw := logtail.NewLogger(conf, logf) var logOutput io.Writer = lw @@ -612,24 +622,27 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor) } } - log.SetFlags(0) // other log flags are set on console, not here - log.SetOutput(logOutput) + if useStdLogger { + log.SetFlags(0) // other log flags are set on console, not here + log.SetOutput(logOutput) + } - log.Printf("Program starting: v%v, Go %v: %#v", + logf("Program starting: v%v, Go %v: %#v", version.Long(), goVersion(), os.Args) - log.Printf("LogID: %v", newc.PublicID) + logf("LogID: %v", newc.PublicID) if filchErr != nil { - log.Printf("filch failed: %v", filchErr) + logf("filch failed: %v", filchErr) } if earlyErrBuf.Len() != 0 { - log.Printf("%s", earlyErrBuf.Bytes()) + logf("%s", earlyErrBuf.Bytes()) } return &Policy{ Logtail: lw, PublicID: newc.PublicID, + Logf: logf, } } @@ -666,7 +679,7 @@ func (p *Policy) Close() { // log upload if it can be done before ctx is canceled. 
func (p *Policy) Shutdown(ctx context.Context) error { if p.Logtail != nil { - log.Printf("flushing log.") + p.Logf("flushing log.") return p.Logtail.Shutdown(ctx) } return nil @@ -680,14 +693,14 @@ func (p *Policy) Shutdown(ctx context.Context) error { // for the benefit of older OS platforms which might not include it. // // The netMon parameter is optional; if non-nil it's used to do faster interface lookups. -func MakeDialFunc(netMon *netmon.Monitor) func(ctx context.Context, netw, addr string) (net.Conn, error) { +func MakeDialFunc(netMon *netmon.Monitor, logf logger.Logf) func(ctx context.Context, netw, addr string) (net.Conn, error) { return func(ctx context.Context, netw, addr string) (net.Conn, error) { - return dialContext(ctx, netw, addr, netMon) + return dialContext(ctx, netw, addr, netMon, logf) } } -func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor) (net.Conn, error) { - nd := netns.FromDialer(log.Printf, netMon, &net.Dialer{ +func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor, logf logger.Logf) (net.Conn, error) { + nd := netns.FromDialer(logf, netMon, &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: netknob.PlatformTCPKeepAlive(), }) @@ -708,7 +721,7 @@ func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor) err = errors.New(res.Status) } if err != nil { - log.Printf("logtail: CONNECT response error from tailscaled: %v", err) + logf("logtail: CONNECT response error from tailscaled: %v", err) c.Close() } else { dialLog.Printf("connected via tailscaled") @@ -718,26 +731,30 @@ func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor) } // If we failed to dial, try again with bootstrap DNS. 
- log.Printf("logtail: dial %q failed: %v (in %v), trying bootstrap...", addr, err, d) + logf("logtail: dial %q failed: %v (in %v), trying bootstrap...", addr, err, d) dnsCache := &dnscache.Resolver{ Forward: dnscache.Get().Forward, // use default cache's forwarder UseLastGood: true, - LookupIPFallback: dnsfallback.MakeLookupFunc(log.Printf, netMon), + LookupIPFallback: dnsfallback.MakeLookupFunc(logf, netMon), NetMon: netMon, } dialer := dnscache.Dialer(nd.DialContext, dnsCache) c, err = dialer(ctx, netw, addr) if err == nil { - log.Printf("logtail: bootstrap dial succeeded") + logf("logtail: bootstrap dial succeeded") } return c, err } // NewLogtailTransport returns an HTTP Transport particularly suited to uploading // logs to the given host name. See DialContext for details on how it works. +// // The netMon parameter is optional; if non-nil it's used to do faster interface lookups. -func NewLogtailTransport(host string, netMon *netmon.Monitor) http.RoundTripper { - if inTest() { +// +// The logf parameter is optional; if non-nil, logs are printed using the +// provided function; if nil, log.Printf will be used instead. +func NewLogtailTransport(host string, netMon *netmon.Monitor, logf logger.Logf) http.RoundTripper { + if testenv.InTest() { return noopPretendSuccessTransport{} } // Start with a copy of http.DefaultTransport and tweak it a bit. @@ -752,7 +769,10 @@ func NewLogtailTransport(host string, netMon *netmon.Monitor) http.RoundTripper tr.DisableCompression = true // Log whenever we dial: - tr.DialContext = MakeDialFunc(netMon) + if logf == nil { + logf = log.Printf + } + tr.DialContext = MakeDialFunc(netMon, logf) // We're contacting exactly 1 hostname, so the default's 100 // max idle conns is very high for our needs. 
Even 2 is diff --git a/vendor/tailscale.com/logtail/backoff/backoff.go b/vendor/tailscale.com/logtail/backoff/backoff.go index ffec64ecd7..72831f5926 100644 --- a/vendor/tailscale.com/logtail/backoff/backoff.go +++ b/vendor/tailscale.com/logtail/backoff/backoff.go @@ -9,6 +9,7 @@ import ( "math/rand" "time" + "tailscale.com/tstime" "tailscale.com/types/logger" ) @@ -23,9 +24,8 @@ type Backoff struct { // logf is the function used for log messages when backing off. logf logger.Logf - // NewTimer is the function that acts like time.NewTimer. - // It's for use in unit tests. - NewTimer func(time.Duration) *time.Timer + // tstime.Clock.NewTimer is used instead time.NewTimer. + Clock tstime.Clock // LogLongerThan sets the minimum time of a single backoff interval // before we mention it in the log. @@ -40,7 +40,7 @@ func NewBackoff(name string, logf logger.Logf, maxBackoff time.Duration) *Backof name: name, logf: logf, maxBackoff: maxBackoff, - NewTimer: time.NewTimer, + Clock: tstime.StdClock{}, } } @@ -72,10 +72,10 @@ func (b *Backoff) BackOff(ctx context.Context, err error) { if d >= b.LogLongerThan { b.logf("%s: [v1] backoff: %d msec", b.name, d.Milliseconds()) } - t := b.NewTimer(d) + t, tChannel := b.Clock.NewTimer(d) select { case <-ctx.Done(): t.Stop() - case <-t.C: + case <-tChannel: } } diff --git a/vendor/tailscale.com/logtail/filch/filch_wasm.go b/vendor/tailscale.com/logtail/filch/filch_stub.go similarity index 89% rename from vendor/tailscale.com/logtail/filch/filch_wasm.go rename to vendor/tailscale.com/logtail/filch/filch_stub.go index 019ee4e144..3bb82b1906 100644 --- a/vendor/tailscale.com/logtail/filch/filch_wasm.go +++ b/vendor/tailscale.com/logtail/filch/filch_stub.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build wasm || plan9 || tamago + package filch import ( diff --git a/vendor/tailscale.com/logtail/filch/filch_unix.go b/vendor/tailscale.com/logtail/filch/filch_unix.go index 
34cce59b40..2eae70aceb 100644 --- a/vendor/tailscale.com/logtail/filch/filch_unix.go +++ b/vendor/tailscale.com/logtail/filch/filch_unix.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !windows && !wasm +//go:build !windows && !wasm && !plan9 && !tamago package filch diff --git a/vendor/tailscale.com/logtail/logtail.go b/vendor/tailscale.com/logtail/logtail.go index 26abe09185..4544af9d79 100644 --- a/vendor/tailscale.com/logtail/logtail.go +++ b/vendor/tailscale.com/logtail/logtail.go @@ -22,7 +22,6 @@ import ( "time" "tailscale.com/envknob" - "tailscale.com/net/interfaces" "tailscale.com/net/netmon" "tailscale.com/net/sockstats" "tailscale.com/tstime" @@ -49,18 +48,18 @@ type Encoder interface { } type Config struct { - Collection string // collection name, a domain name - PrivateID logid.PrivateID // private ID for the primary log stream - CopyPrivateID logid.PrivateID // private ID for a log stream that is a superset of this log stream - BaseURL string // if empty defaults to "https://log.tailscale.io" - HTTPC *http.Client // if empty defaults to http.DefaultClient - SkipClientTime bool // if true, client_time is not written to logs - LowMemory bool // if true, logtail minimizes memory use - TimeNow func() time.Time // if set, substitutes uses of time.Now - Stderr io.Writer // if set, logs are sent here instead of os.Stderr - StderrLevel int // max verbosity level to write to stderr; 0 means the non-verbose messages only - Buffer Buffer // temp storage, if nil a MemoryBuffer - NewZstdEncoder func() Encoder // if set, used to compress logs for transmission + Collection string // collection name, a domain name + PrivateID logid.PrivateID // private ID for the primary log stream + CopyPrivateID logid.PrivateID // private ID for a log stream that is a superset of this log stream + BaseURL string // if empty defaults to "https://log.tailscale.io" + HTTPC *http.Client // if empty defaults to 
http.DefaultClient + SkipClientTime bool // if true, client_time is not written to logs + LowMemory bool // if true, logtail minimizes memory use + Clock tstime.Clock // if set, Clock.Now substitutes uses of time.Now + Stderr io.Writer // if set, logs are sent here instead of os.Stderr + StderrLevel int // max verbosity level to write to stderr; 0 means the non-verbose messages only + Buffer Buffer // temp storage, if nil a MemoryBuffer + NewZstdEncoder func() Encoder // if set, used to compress logs for transmission // MetricsDelta, if non-nil, is a func that returns an encoding // delta in clientmetrics to upload alongside existing logs. @@ -94,8 +93,8 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { if cfg.HTTPC == nil { cfg.HTTPC = http.DefaultClient } - if cfg.TimeNow == nil { - cfg.TimeNow = time.Now + if cfg.Clock == nil { + cfg.Clock = tstime.StdClock{} } if cfg.Stderr == nil { cfg.Stderr = os.Stderr @@ -144,9 +143,8 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { drainWake: make(chan struct{}, 1), sentinel: make(chan int32, 16), flushDelayFn: cfg.FlushDelayFn, - timeNow: cfg.TimeNow, + clock: cfg.Clock, metricsDelta: cfg.MetricsDelta, - sockstatsLabel: sockstats.LabelLogtailLogger, procID: procID, includeProcSequence: cfg.IncludeProcSequence, @@ -154,6 +152,7 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { shutdownStart: make(chan struct{}), shutdownDone: make(chan struct{}), } + l.SetSockstatsLabel(sockstats.LabelLogtailLogger) if cfg.NewZstdEncoder != nil { l.zstdEncoder = cfg.NewZstdEncoder() } @@ -181,27 +180,32 @@ type Logger struct { flushDelayFn func() time.Duration // negative or zero return value to upload aggressively, or >0 to batch at this delay flushPending atomic.Bool sentinel chan int32 - timeNow func() time.Time + clock tstime.Clock zstdEncoder Encoder uploadCancel func() explainedRaw bool metricsDelta func() string // or nil privateID logid.PrivateID httpDoCalls atomic.Int32 - sockstatsLabel 
sockstats.Label + sockstatsLabel atomicSocktatsLabel procID uint32 includeProcSequence bool writeLock sync.Mutex // guards procSequence, flushTimer, buffer.Write calls procSequence uint64 - flushTimer *time.Timer // used when flushDelay is >0 + flushTimer tstime.TimerController // used when flushDelay is >0 shutdownStartMu sync.Mutex // guards the closing of shutdownStart shutdownStart chan struct{} // closed when shutdown begins shutdownDone chan struct{} // closed when shutdown complete } +type atomicSocktatsLabel struct{ p atomic.Uint32 } + +func (p *atomicSocktatsLabel) Load() sockstats.Label { return sockstats.Label(p.p.Load()) } +func (p *atomicSocktatsLabel) Store(label sockstats.Label) { p.p.Store(uint32(label)) } + // SetVerbosityLevel controls the verbosity level that should be // written to stderr. 0 is the default (not verbose). Levels 1 or higher // are increasingly verbose. @@ -219,7 +223,7 @@ func (l *Logger) SetNetMon(lm *netmon.Monitor) { // SetSockstatsLabel sets the label used in sockstat logs to identify network traffic from this logger. func (l *Logger) SetSockstatsLabel(label sockstats.Label) { - l.sockstatsLabel = label + l.sockstatsLabel.Store(label) } // PrivateID returns the logger's private log ID. @@ -375,7 +379,7 @@ func (l *Logger) uploading(ctx context.Context) { retryAfter, err := l.upload(ctx, body, origlen) if err != nil { numFailures++ - firstFailure = time.Now() + firstFailure = l.clock.Now() if !l.internetUp() { fmt.Fprintf(l.stderr, "logtail: internet down; waiting\n") @@ -398,7 +402,7 @@ func (l *Logger) uploading(ctx context.Context) { } else { // Only print a success message after recovery. 
if numFailures > 0 { - fmt.Fprintf(l.stderr, "logtail: upload succeeded after %d failures and %s\n", numFailures, time.Since(firstFailure).Round(time.Second)) + fmt.Fprintf(l.stderr, "logtail: upload succeeded after %d failures and %s\n", numFailures, l.clock.Since(firstFailure).Round(time.Second)) } break } @@ -422,8 +426,8 @@ func (l *Logger) internetUp() bool { func (l *Logger) awaitInternetUp(ctx context.Context) { upc := make(chan bool, 1) - defer l.netMonitor.RegisterChangeCallback(func(changed bool, st *interfaces.State) { - if st.AnyInterfaceUp() { + defer l.netMonitor.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { + if delta.New.AnyInterfaceUp() { select { case upc <- true: default: @@ -445,7 +449,7 @@ func (l *Logger) awaitInternetUp(ctx context.Context) { // origlen of -1 indicates that the body is not compressed. func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAfter time.Duration, err error) { const maxUploadTime = 45 * time.Second - ctx = sockstats.WithSockStats(ctx, l.sockstatsLabel, l.Logf) + ctx = sockstats.WithSockStats(ctx, l.sockstatsLabel.Load(), l.Logf) ctx, cancel := context.WithTimeout(ctx, maxUploadTime) defer cancel() @@ -540,7 +544,7 @@ func (l *Logger) sendLocked(jsonBlob []byte) (int, error) { if flushDelay > 0 { if l.flushPending.CompareAndSwap(false, true) { if l.flushTimer == nil { - l.flushTimer = time.AfterFunc(flushDelay, l.tryDrainWake) + l.flushTimer = l.clock.AfterFunc(flushDelay, l.tryDrainWake) } else { l.flushTimer.Reset(flushDelay) } @@ -554,7 +558,7 @@ func (l *Logger) sendLocked(jsonBlob []byte) (int, error) { // TODO: instead of allocating, this should probably just append // directly into the output log buffer. 
func (l *Logger) encodeText(buf []byte, skipClientTime bool, procID uint32, procSequence uint64, level int) []byte { - now := l.timeNow() + now := l.clock.Now() // Factor in JSON encoding overhead to try to only do one alloc // in the make below (so appends don't resize the buffer). @@ -669,7 +673,7 @@ func (l *Logger) encodeLocked(buf []byte, level int) []byte { return l.encodeText(buf, l.skipClientTime, l.procID, l.procSequence, level) // text fast-path } - now := l.timeNow() + now := l.clock.Now() obj := make(map[string]any) if err := json.Unmarshal(buf, &obj); err != nil { diff --git a/vendor/tailscale.com/metrics/metrics.go b/vendor/tailscale.com/metrics/metrics.go index 5a323dc389..a07ddccae5 100644 --- a/vendor/tailscale.com/metrics/metrics.go +++ b/vendor/tailscale.com/metrics/metrics.go @@ -5,7 +5,13 @@ // Tailscale for monitoring. package metrics -import "expvar" +import ( + "expvar" + "fmt" + "io" + "slices" + "strings" +) // Set is a string-to-Var map variable that satisfies the expvar.Var // interface. @@ -45,6 +51,14 @@ func (m *LabelMap) Get(key string) *expvar.Int { return m.Map.Get(key).(*expvar.Int) } +// GetIncrFunc returns a function that increments the expvar.Int named by key. +// +// Most callers should not need this; it exists to satisfy an +// interface elsewhere. +func (m *LabelMap) GetIncrFunc(key string) func(delta int64) { + return m.Get(key).Add +} + // GetFloat returns a direct pointer to the expvar.Float for key, creating it // if necessary. func (m *LabelMap) GetFloat(key string) *expvar.Float { @@ -58,3 +72,92 @@ func (m *LabelMap) GetFloat(key string) *expvar.Float { func CurrentFDs() int { return currentFDs() } + +// Histogram is a histogram of values. +// It should be created with NewHistogram. +type Histogram struct { + // buckets is a list of bucket boundaries, in increasing order. + buckets []float64 + + // bucketStrings is a list of the same buckets, but as strings. 
+ // This are allocated once at creation time by NewHistogram. + bucketStrings []string + + bucketVars []expvar.Int + sum expvar.Float + count expvar.Int +} + +// NewHistogram returns a new histogram that reports to the given +// expvar map under the given name. +// +// The buckets are the boundaries of the histogram buckets, in +// increasing order. The last bucket is +Inf. +func NewHistogram(buckets []float64) *Histogram { + if !slices.IsSorted(buckets) { + panic("buckets must be sorted") + } + labels := make([]string, len(buckets)) + for i, b := range buckets { + labels[i] = fmt.Sprintf("%v", b) + } + h := &Histogram{ + buckets: buckets, + bucketStrings: labels, + bucketVars: make([]expvar.Int, len(buckets)), + } + return h +} + +// Observe records a new observation in the histogram. +func (h *Histogram) Observe(v float64) { + h.sum.Add(v) + h.count.Add(1) + for i, b := range h.buckets { + if v <= b { + h.bucketVars[i].Add(1) + } + } +} + +// String returns a JSON representation of the histogram. +// This is used to satisfy the expvar.Var interface. +func (h *Histogram) String() string { + var b strings.Builder + fmt.Fprintf(&b, "{") + first := true + h.Do(func(kv expvar.KeyValue) { + if !first { + fmt.Fprintf(&b, ",") + } + fmt.Fprintf(&b, "%q: ", kv.Key) + if kv.Value != nil { + fmt.Fprintf(&b, "%v", kv.Value) + } else { + fmt.Fprint(&b, "null") + } + first = false + }) + fmt.Fprintf(&b, ",\"sum\": %v", &h.sum) + fmt.Fprintf(&b, ",\"count\": %v", &h.count) + fmt.Fprintf(&b, "}") + return b.String() +} + +// Do calls f for each bucket in the histogram. +func (h *Histogram) Do(f func(expvar.KeyValue)) { + for i := range h.bucketVars { + f(expvar.KeyValue{Key: h.bucketStrings[i], Value: &h.bucketVars[i]}) + } + f(expvar.KeyValue{Key: "+Inf", Value: &h.count}) +} + +// PromExport writes the histogram to w in Prometheus exposition format. 
+func (h *Histogram) PromExport(w io.Writer, name string) { + fmt.Fprintf(w, "# TYPE %s histogram\n", name) + h.Do(func(kv expvar.KeyValue) { + fmt.Fprintf(w, "%s_bucket{le=%q} %v\n", name, kv.Key, kv.Value) + }) + fmt.Fprintf(w, "%s_sum %v\n", name, &h.sum) + fmt.Fprintf(w, "%s_count %v\n", name, &h.count) +} diff --git a/vendor/tailscale.com/net/dns/direct.go b/vendor/tailscale.com/net/dns/direct.go index 0cae70d0e2..e9279d13a2 100644 --- a/vendor/tailscale.com/net/dns/direct.go +++ b/vendor/tailscale.com/net/dns/direct.go @@ -27,11 +27,6 @@ import ( "tailscale.com/version/distro" ) -const ( - backupConf = "/etc/resolv.pre-tailscale-backup.conf" - resolvConf = "/etc/resolv.conf" -) - // writeResolvConf writes DNS configuration in resolv.conf format to the given writer. func writeResolvConf(w io.Writer, servers []netip.Addr, domains []dnsname.FQDN) error { c := &resolvconffile.Config{ diff --git a/vendor/tailscale.com/net/dns/manager.go b/vendor/tailscale.com/net/dns/manager.go index d1aa73ca67..cee7d7ede8 100644 --- a/vendor/tailscale.com/net/dns/manager.go +++ b/vendor/tailscale.com/net/dns/manager.go @@ -12,10 +12,11 @@ import ( "net" "net/netip" "runtime" + "slices" + "strings" "sync/atomic" "time" - "golang.org/x/exp/slices" "tailscale.com/health" "tailscale.com/net/dns/resolver" "tailscale.com/net/netmon" @@ -96,7 +97,9 @@ func (m *Manager) Set(cfg Config) error { m.logf("Resolvercfg: %v", logger.ArgWriter(func(w *bufio.Writer) { rcfg.WriteToBufioWriter(w) })) - m.logf("OScfg: %+v", ocfg) + m.logf("OScfg: %v", logger.ArgWriter(func(w *bufio.Writer) { + ocfg.WriteToBufioWriter(w) + })) if err := m.resolver.SetConfig(rcfg); err != nil { return err @@ -139,14 +142,15 @@ func compileHostEntries(cfg Config) (hosts []*HostEntry) { } } } - slices.SortFunc(hosts, func(a, b *HostEntry) bool { - if len(a.Hosts) == 0 { - return false - } - if len(b.Hosts) == 0 { - return true + slices.SortFunc(hosts, func(a, b *HostEntry) int { + if len(a.Hosts) == 0 && len(b.Hosts) == 
0 { + return 0 + } else if len(a.Hosts) == 0 { + return -1 + } else if len(b.Hosts) == 0 { + return 1 } - return a.Hosts[0] < b.Hosts[0] + return strings.Compare(a.Hosts[0], b.Hosts[0]) }) return hosts } diff --git a/vendor/tailscale.com/net/dns/manager_linux.go b/vendor/tailscale.com/net/dns/manager_linux.go index b61c362894..642fe99c50 100644 --- a/vendor/tailscale.com/net/dns/manager_linux.go +++ b/vendor/tailscale.com/net/dns/manager_linux.go @@ -139,7 +139,7 @@ func dnsMode(logf logger.Logf, env newOSConfigEnv) (ret string, err error) { // header, but doesn't actually point to resolved. We mustn't // try to program resolved in that case. // https://github.com/tailscale/tailscale/issues/2136 - if err := resolvedIsActuallyResolver(bs); err != nil { + if err := resolvedIsActuallyResolver(logf, env, dbg, bs); err != nil { logf("dns: resolvedIsActuallyResolver error: %v", err) dbg("resolved", "not-in-use") return "direct", nil @@ -225,7 +225,7 @@ func dnsMode(logf logger.Logf, env newOSConfigEnv) (ret string, err error) { dbg("rc", "nm") // Sometimes, NetworkManager owns the configuration but points // it at systemd-resolved. - if err := resolvedIsActuallyResolver(bs); err != nil { + if err := resolvedIsActuallyResolver(logf, env, dbg, bs); err != nil { logf("dns: resolvedIsActuallyResolver error: %v", err) dbg("resolved", "not-in-use") // You'd think we would use newNMManager here. However, as @@ -318,14 +318,23 @@ func nmIsUsingResolved() error { return nil } -// resolvedIsActuallyResolver reports whether the given resolv.conf -// bytes describe a configuration where systemd-resolved (127.0.0.53) -// is the only configured nameserver. +// resolvedIsActuallyResolver reports whether the system is using +// systemd-resolved as the resolver. 
There are two different ways to +// use systemd-resolved: +// - libnss_resolve, which requires adding `resolve` to the "hosts:" +// line in /etc/nsswitch.conf +// - setting the only nameserver configured in `resolv.conf` to +// systemd-resolved IP (127.0.0.53) // // Returns an error if the configuration is something other than // exclusively systemd-resolved, or nil if the config is only // systemd-resolved. -func resolvedIsActuallyResolver(bs []byte) error { +func resolvedIsActuallyResolver(logf logger.Logf, env newOSConfigEnv, dbg func(k, v string), bs []byte) error { + if err := isLibnssResolveUsed(env); err == nil { + dbg("resolved", "nss") + return nil + } + cfg, err := readResolv(bytes.NewBuffer(bs)) if err != nil { return err @@ -342,9 +351,34 @@ func resolvedIsActuallyResolver(bs []byte) error { return fmt.Errorf("resolv.conf doesn't point to systemd-resolved; points to %v", cfg.Nameservers) } } + dbg("resolved", "file") return nil } +// isLibnssResolveUsed reports whether libnss_resolve is used +// for resolving names. Returns nil if it is, and an error otherwise. 
+func isLibnssResolveUsed(env newOSConfigEnv) error { + bs, err := env.fs.ReadFile("/etc/nsswitch.conf") + if err != nil { + return fmt.Errorf("reading /etc/resolv.conf: %w", err) + } + for _, line := range strings.Split(string(bs), "\n") { + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "hosts:" { + continue + } + for _, module := range fields[1:] { + if module == "dns" { + return fmt.Errorf("dns with a higher priority than libnss_resolve") + } + if module == "resolve" { + return nil + } + } + } + return fmt.Errorf("libnss_resolve not used") +} + func dbusPing(name, objectPath string) error { conn, err := dbus.SystemBus() if err != nil { diff --git a/vendor/tailscale.com/net/dns/nm.go b/vendor/tailscale.com/net/dns/nm.go index 664297c630..4d9fbca669 100644 --- a/vendor/tailscale.com/net/dns/nm.go +++ b/vendor/tailscale.com/net/dns/nm.go @@ -8,13 +8,14 @@ package dns import ( "context" "fmt" + "net" "net/netip" "sort" "time" "github.com/godbus/dbus/v5" "github.com/josharian/native" - "tailscale.com/net/interfaces" + "tailscale.com/net/tsaddr" "tailscale.com/util/dnsname" ) @@ -139,14 +140,18 @@ func (m *nmManager) trySet(ctx context.Context, config OSConfig) error { // tell it explicitly to keep it. Read out the current interface // settings and mirror them out to NetworkManager. 
var addrs6 []map[string]any - addrs, _, err := interfaces.Tailscale() - if err == nil { + if tsIf, err := net.InterfaceByName(m.interfaceName); err == nil { + addrs, _ := tsIf.Addrs() for _, a := range addrs { - if a.Is6() { - addrs6 = append(addrs6, map[string]any{ - "address": a.String(), - "prefix": uint32(128), - }) + if ipnet, ok := a.(*net.IPNet); ok { + nip, ok := netip.AddrFromSlice(ipnet.IP) + nip = nip.Unmap() + if ok && tsaddr.IsTailscaleIP(nip) && nip.Is6() { + addrs6 = append(addrs6, map[string]any{ + "address": nip.String(), + "prefix": uint32(128), + }) + } } } } diff --git a/vendor/tailscale.com/net/dns/nrpt_windows.go b/vendor/tailscale.com/net/dns/nrpt_windows.go index f81cdb42f3..78a7026160 100644 --- a/vendor/tailscale.com/net/dns/nrpt_windows.go +++ b/vendor/tailscale.com/net/dns/nrpt_windows.go @@ -13,6 +13,7 @@ import ( "golang.org/x/sys/windows/registry" "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/util/set" "tailscale.com/util/winutil" ) @@ -158,14 +159,14 @@ func (db *nrptRuleDatabase) detectWriteAsGP() { } // Add *all* rules from the GP subkey into a set. - gpSubkeyMap := make(map[string]struct{}, len(gpSubkeyNames)) + gpSubkeyMap := make(set.Set[string], len(gpSubkeyNames)) for _, gpSubkey := range gpSubkeyNames { - gpSubkeyMap[strings.ToUpper(gpSubkey)] = struct{}{} + gpSubkeyMap.Add(strings.ToUpper(gpSubkey)) } // Remove *our* rules from the set. for _, ourRuleID := range db.ruleIDs { - delete(gpSubkeyMap, strings.ToUpper(ourRuleID)) + gpSubkeyMap.Delete(strings.ToUpper(ourRuleID)) } // Any leftover rules do not belong to us. 
When group policy is being used diff --git a/vendor/tailscale.com/net/dns/osconfig.go b/vendor/tailscale.com/net/dns/osconfig.go index b12e6418b6..f3c9016a76 100644 --- a/vendor/tailscale.com/net/dns/osconfig.go +++ b/vendor/tailscale.com/net/dns/osconfig.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net/netip" + "strings" "tailscale.com/types/logger" "tailscale.com/util/dnsname" @@ -65,6 +66,42 @@ type OSConfig struct { MatchDomains []dnsname.FQDN } +func (o *OSConfig) WriteToBufioWriter(w *bufio.Writer) { + if o == nil { + w.WriteString("") + return + } + w.WriteString("{") + if len(o.Hosts) > 0 { + fmt.Fprintf(w, "Hosts:%v ", o.Hosts) + } + if len(o.Nameservers) > 0 { + fmt.Fprintf(w, "Nameservers:%v ", o.Nameservers) + } + if len(o.SearchDomains) > 0 { + fmt.Fprintf(w, "SearchDomains:%v ", o.SearchDomains) + } + if len(o.MatchDomains) > 0 { + w.WriteString("SearchDomains:[") + sp := "" + var numARPA int + for _, s := range o.MatchDomains { + if strings.HasSuffix(string(s), ".arpa.") { + numARPA++ + continue + } + w.WriteString(sp) + w.WriteString(string(s)) + sp = " " + } + w.WriteString("]") + if numARPA > 0 { + fmt.Fprintf(w, "+%darpa", numARPA) + } + } + w.WriteString("}") +} + func (o OSConfig) IsZero() bool { return len(o.Nameservers) == 0 && len(o.SearchDomains) == 0 && len(o.MatchDomains) == 0 } diff --git a/vendor/tailscale.com/net/dns/publicdns/publicdns.go b/vendor/tailscale.com/net/dns/publicdns/publicdns.go index 806dea431e..40c1b0f8f8 100644 --- a/vendor/tailscale.com/net/dns/publicdns/publicdns.go +++ b/vendor/tailscale.com/net/dns/publicdns/publicdns.go @@ -129,9 +129,15 @@ func addDoH(ipStr, base string) { dohIPsOfBase[base] = append(dohIPsOfBase[base], ip) } +const ( + wikimediaDNSv4 = "185.71.138.138" + wikimediaDNSv6 = "2001:67c:930::1" +) + // populate is called once to initialize the knownDoH and dohIPsOfBase maps. 
func populate() { // Cloudflare + // https://developers.cloudflare.com/1.1.1.1/ip-addresses/ addDoH("1.1.1.1", "https://cloudflare-dns.com/dns-query") addDoH("1.0.0.1", "https://cloudflare-dns.com/dns-query") addDoH("2606:4700:4700::1111", "https://cloudflare-dns.com/dns-query") @@ -165,10 +171,17 @@ func populate() { // addDoH("208.67.220.123", "https://doh.familyshield.opendns.com/dns-query") // Quad9 + // https://www.quad9.net/service/service-addresses-and-features addDoH("9.9.9.9", "https://dns.quad9.net/dns-query") addDoH("149.112.112.112", "https://dns.quad9.net/dns-query") addDoH("2620:fe::fe", "https://dns.quad9.net/dns-query") - addDoH("2620:fe::fe:9", "https://dns.quad9.net/dns-query") + addDoH("2620:fe::9", "https://dns.quad9.net/dns-query") + + // Quad9 +ECS +DNSSEC + addDoH("9.9.9.11", "https://dns11.quad9.net/dns-query") + addDoH("149.112.112.11", "https://dns11.quad9.net/dns-query") + addDoH("2620:fe::11", "https://dns11.quad9.net/dns-query") + addDoH("2620:fe::fe:11", "https://dns11.quad9.net/dns-query") // Quad9 -DNSSEC addDoH("9.9.9.10", "https://dns10.quad9.net/dns-query") @@ -177,14 +190,26 @@ func populate() { addDoH("2620:fe::fe:10", "https://dns10.quad9.net/dns-query") // Mullvad - addDoH("194.242.2.2", "https://doh.mullvad.net/dns-query") - addDoH("193.19.108.2", "https://doh.mullvad.net/dns-query") - addDoH("2a07:e340::2", "https://doh.mullvad.net/dns-query") - - // Mullvad -Ads - addDoH("194.242.2.3", "https://adblock.doh.mullvad.net/dns-query") - addDoH("193.19.108.3", "https://adblock.doh.mullvad.net/dns-query") - addDoH("2a07:e340::3", "https://adblock.doh.mullvad.net/dns-query") + // See https://mullvad.net/en/help/dns-over-https-and-dns-over-tls/ + // Mullvad (default) + addDoH("194.242.2.2", "https://dns.mullvad.net/dns-query") + addDoH("2a07:e340::2", "https://dns.mullvad.net/dns-query") + // Mullvad (adblock) + addDoH("194.242.2.3", "https://adblock.dns.mullvad.net/dns-query") + addDoH("2a07:e340::3", 
"https://adblock.dns.mullvad.net/dns-query") + // Mullvad (base) + addDoH("194.242.2.4", "https://base.dns.mullvad.net/dns-query") + addDoH("2a07:e340::4", "https://base.dns.mullvad.net/dns-query") + // Mullvad (extended) + addDoH("194.242.2.5", "https://extended.dns.mullvad.net/dns-query") + addDoH("2a07:e340::5", "https://extended.dns.mullvad.net/dns-query") + // Mullvad (all) + addDoH("194.242.2.9", "https://all.dns.mullvad.net/dns-query") + addDoH("2a07:e340::9", "https://all.dns.mullvad.net/dns-query") + + // Wikimedia + addDoH(wikimediaDNSv4, "https://wikimedia-dns.org/dns-query") + addDoH(wikimediaDNSv6, "https://wikimedia-dns.org/dns-query") } var ( @@ -207,6 +232,10 @@ var ( nextDNSv4RangeB = netip.MustParsePrefix("45.90.30.0/24") nextDNSv4One = nextDNSv4RangeA.Addr() nextDNSv4Two = nextDNSv4RangeB.Addr() + + // Wikimedia DNS server IPs (anycast) + wikimediaDNSv4Addr = netip.MustParseAddr(wikimediaDNSv4) + wikimediaDNSv6Addr = netip.MustParseAddr(wikimediaDNSv6) ) // nextDNSv6Gen generates a NextDNS IPv6 address from the upper 8 bytes in the @@ -224,5 +253,6 @@ func nextDNSv6Gen(ip netip.Addr, id []byte) netip.Addr { // DNS-over-HTTPS (not regular port 53 DNS). func IPIsDoHOnlyServer(ip netip.Addr) bool { return nextDNSv6RangeA.Contains(ip) || nextDNSv6RangeB.Contains(ip) || - nextDNSv4RangeA.Contains(ip) || nextDNSv4RangeB.Contains(ip) + nextDNSv4RangeA.Contains(ip) || nextDNSv4RangeB.Contains(ip) || + ip == wikimediaDNSv4Addr || ip == wikimediaDNSv6Addr } diff --git a/vendor/tailscale.com/net/dns/recursive/recursive.go b/vendor/tailscale.com/net/dns/recursive/recursive.go new file mode 100644 index 0000000000..5b585483ca --- /dev/null +++ b/vendor/tailscale.com/net/dns/recursive/recursive.go @@ -0,0 +1,628 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package recursive implements a simple recursive DNS resolver. 
+package recursive + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "slices" + "strings" + "time" + + "github.com/miekg/dns" + "tailscale.com/envknob" + "tailscale.com/net/netns" + "tailscale.com/types/logger" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" + "tailscale.com/util/multierr" + "tailscale.com/util/slicesx" +) + +const ( + // maxDepth is how deep from the root nameservers we'll recurse when + // resolving; passing this limit will instead return an error. + // + // maxDepth must be at least 20 to resolve "console.aws.amazon.com", + // which is a domain with a moderately complicated DNS setup. The + // current value of 30 was chosen semi-arbitrarily to ensure that we + // have about 50% headroom. + maxDepth = 30 + // numStartingServers is the number of root nameservers that we use as + // initial candidates for our recursion. + numStartingServers = 3 + // udpQueryTimeout is the amount of time we wait for a UDP response + // from a nameserver before falling back to a TCP connection. + udpQueryTimeout = 5 * time.Second + + // These constants aren't typed in the DNS package, so we create typed + // versions here to avoid having to do repeated type casts. + qtypeA dns.Type = dns.Type(dns.TypeA) + qtypeAAAA dns.Type = dns.Type(dns.TypeAAAA) +) + +var ( + // ErrMaxDepth is returned when recursive resolving exceeds the maximum + // depth limit for this package. + ErrMaxDepth = fmt.Errorf("exceeded max depth %d when resolving", maxDepth) + + // ErrAuthoritativeNoResponses is the error returned when an + // authoritative nameserver indicates that there are no responses to + // the given query. + ErrAuthoritativeNoResponses = errors.New("authoritative server returned no responses") + + // ErrNoResponses is returned when our resolution process completes + // with no valid responses from any nameserver, but no authoritative + // server explicitly returned NXDOMAIN. 
+ ErrNoResponses = errors.New("no responses to query") +) + +var rootServersV4 = []netip.Addr{ + netip.MustParseAddr("198.41.0.4"), // a.root-servers.net + netip.MustParseAddr("199.9.14.201"), // b.root-servers.net + netip.MustParseAddr("192.33.4.12"), // c.root-servers.net + netip.MustParseAddr("199.7.91.13"), // d.root-servers.net + netip.MustParseAddr("192.203.230.10"), // e.root-servers.net + netip.MustParseAddr("192.5.5.241"), // f.root-servers.net + netip.MustParseAddr("192.112.36.4"), // g.root-servers.net + netip.MustParseAddr("198.97.190.53"), // h.root-servers.net + netip.MustParseAddr("192.36.148.17"), // i.root-servers.net + netip.MustParseAddr("192.58.128.30"), // j.root-servers.net + netip.MustParseAddr("193.0.14.129"), // k.root-servers.net + netip.MustParseAddr("199.7.83.42"), // l.root-servers.net + netip.MustParseAddr("202.12.27.33"), // m.root-servers.net +} + +var rootServersV6 = []netip.Addr{ + netip.MustParseAddr("2001:503:ba3e::2:30"), // a.root-servers.net + netip.MustParseAddr("2001:500:200::b"), // b.root-servers.net + netip.MustParseAddr("2001:500:2::c"), // c.root-servers.net + netip.MustParseAddr("2001:500:2d::d"), // d.root-servers.net + netip.MustParseAddr("2001:500:a8::e"), // e.root-servers.net + netip.MustParseAddr("2001:500:2f::f"), // f.root-servers.net + netip.MustParseAddr("2001:500:12::d0d"), // g.root-servers.net + netip.MustParseAddr("2001:500:1::53"), // h.root-servers.net + netip.MustParseAddr("2001:7fe::53"), // i.root-servers.net + netip.MustParseAddr("2001:503:c27::2:30"), // j.root-servers.net + netip.MustParseAddr("2001:7fd::1"), // k.root-servers.net + netip.MustParseAddr("2001:500:9f::42"), // l.root-servers.net + netip.MustParseAddr("2001:dc3::35"), // m.root-servers.net +} + +var debug = envknob.RegisterBool("TS_DEBUG_RECURSIVE_DNS") + +// Resolver is a recursive DNS resolver that is designed for looking up A and AAAA records. +type Resolver struct { + // Dialer is used to create outbound connections. 
If nil, a zero + // net.Dialer will be used instead. + Dialer netns.Dialer + + // Logf is the logging function to use; if none is specified, then logs + // will be dropped. + Logf logger.Logf + + // NoIPv6, if set, will prevent this package from querying for AAAA + // records and will avoid contacting nameservers over IPv6. + NoIPv6 bool + + // Test mocks + testQueryHook func(name dnsname.FQDN, nameserver netip.Addr, protocol string, qtype dns.Type) (*dns.Msg, error) + testExchangeHook func(nameserver netip.Addr, network string, msg *dns.Msg) (*dns.Msg, error) + rootServers []netip.Addr + timeNow func() time.Time + + // Caching + // NOTE(andrew): if we make resolution parallel, this needs a mutex + queryCache map[dnsQuery]dnsMsgWithExpiry + + // Possible future additions: + // - Additional nameservers? From the system maybe? + // - NoIPv4 for IPv4 + // - DNS-over-HTTPS or DNS-over-TLS support +} + +// queryState stores all state during the course of a single query +type queryState struct { + // rootServers are the root nameservers to start from + rootServers []netip.Addr + + // TODO: metrics? +} + +type dnsQuery struct { + nameserver netip.Addr + name dnsname.FQDN + qtype dns.Type +} + +func (q dnsQuery) String() string { + return fmt.Sprintf("dnsQuery{nameserver:%q,name:%q,qtype:%v}", q.nameserver.String(), q.name, q.qtype) +} + +type dnsMsgWithExpiry struct { + *dns.Msg + expiresAt time.Time +} + +func (r *Resolver) now() time.Time { + if r.timeNow != nil { + return r.timeNow() + } + return time.Now() +} + +func (r *Resolver) logf(format string, args ...any) { + if r.Logf == nil { + return + } + r.Logf(format, args...) +} + +func (r *Resolver) dlogf(format string, args ...any) { + if r.Logf == nil || !debug() { + return + } + r.Logf(format, args...) 
+} + +func (r *Resolver) depthlogf(depth int, format string, args ...any) { + if r.Logf == nil || !debug() { + return + } + prefix := fmt.Sprintf("[%d] %s", depth, strings.Repeat(" ", depth)) + r.Logf(prefix+format, args...) +} + +var defaultDialer net.Dialer + +func (r *Resolver) dialer() netns.Dialer { + if r.Dialer != nil { + return r.Dialer + } + + return &defaultDialer +} + +func (r *Resolver) newState() *queryState { + var rootServers []netip.Addr + if len(r.rootServers) > 0 { + rootServers = r.rootServers + } else { + // Select a random subset of root nameservers to start from, since if + // we don't get responses from those, something else has probably gone + // horribly wrong. + roots4 := slices.Clone(rootServersV4) + slicesx.Shuffle(roots4) + roots4 = roots4[:numStartingServers] + + var roots6 []netip.Addr + if !r.NoIPv6 { + roots6 = slices.Clone(rootServersV6) + slicesx.Shuffle(roots6) + roots6 = roots6[:numStartingServers] + } + + // Interleave the root servers so that we try to contact them over + // IPv4, then IPv6, IPv4, IPv6, etc. + rootServers = slicesx.Interleave(roots4, roots6) + } + + return &queryState{ + rootServers: rootServers, + } +} + +// Resolve will perform a recursive DNS resolution for the provided name, +// starting at a randomly-chosen root DNS server, and return the A and AAAA +// responses as a slice of netip.Addrs along with the minimum TTL for the +// returned records. 
+func (r *Resolver) Resolve(ctx context.Context, name string) (addrs []netip.Addr, minTTL time.Duration, err error) { + dnsName, err := dnsname.ToFQDN(name) + if err != nil { + return nil, 0, err + } + + qstate := r.newState() + + r.logf("querying IPv4 addresses for: %q", name) + addrs4, minTTL4, err4 := r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeA) + + var ( + addrs6 []netip.Addr + minTTL6 time.Duration + err6 error + ) + if !r.NoIPv6 { + r.logf("querying IPv6 addresses for: %q", name) + addrs6, minTTL6, err6 = r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeAAAA) + } + + if err4 != nil && err6 != nil { + if err4 == err6 { + return nil, 0, err4 + } + + return nil, 0, multierr.New(err4, err6) + } + if err4 != nil { + return addrs6, minTTL6, nil + } else if err6 != nil { + return addrs4, minTTL4, nil + } + + minTTL = minTTL4 + if minTTL6 < minTTL { + minTTL = minTTL6 + } + + addrs = append(addrs4, addrs6...) + if len(addrs) == 0 { + return nil, 0, ErrNoResponses + } + + slicesx.Shuffle(addrs) + return addrs, minTTL, nil +} + +func (r *Resolver) resolveRecursiveFromRoot( + ctx context.Context, + qstate *queryState, + depth int, + name dnsname.FQDN, // what we're querying + qtype dns.Type, +) ([]netip.Addr, time.Duration, error) { + r.depthlogf(depth, "resolving %q from root (type: %v)", name, qtype) + + var depthError bool + for _, server := range qstate.rootServers { + addrs, minTTL, err := r.resolveRecursive(ctx, qstate, depth, name, server, qtype) + if err == nil { + return addrs, minTTL, err + } else if errors.Is(err, ErrAuthoritativeNoResponses) { + return nil, 0, ErrAuthoritativeNoResponses + } else if errors.Is(err, ErrMaxDepth) { + depthError = true + } + } + + if depthError { + return nil, 0, ErrMaxDepth + } + return nil, 0, ErrNoResponses +} + +func (r *Resolver) resolveRecursive( + ctx context.Context, + qstate *queryState, + depth int, + name dnsname.FQDN, // what we're querying + nameserver netip.Addr, + qtype dns.Type, +) 
([]netip.Addr, time.Duration, error) { + if depth == maxDepth { + r.depthlogf(depth, "not recursing past maximum depth") + return nil, 0, ErrMaxDepth + } + + // Ask this nameserver for an answer. + resp, err := r.queryNameserver(ctx, depth, name, nameserver, qtype) + if err != nil { + return nil, 0, err + } + + // If we get an actual answer from the nameserver, then return it. + var ( + answers []netip.Addr + cnames []dnsname.FQDN + minTTL = 24 * 60 * 60 // 24 hours in seconds + ) + for _, answer := range resp.Answer { + if crec, ok := answer.(*dns.CNAME); ok { + cnameFQDN, err := dnsname.ToFQDN(crec.Target) + if err != nil { + r.logf("bad CNAME %q returned: %v", crec.Target, err) + continue + } + + cnames = append(cnames, cnameFQDN) + continue + } + + addr := addrFromRecord(answer) + if !addr.IsValid() { + r.logf("[unexpected] invalid record in %T answer", answer) + } else if addr.Is4() && qtype != qtypeA { + r.logf("[unexpected] got IPv4 answer but qtype=%v", qtype) + } else if addr.Is6() && qtype != qtypeAAAA { + r.logf("[unexpected] got IPv6 answer but qtype=%v", qtype) + } else { + answers = append(answers, addr) + minTTL = min(minTTL, int(answer.Header().Ttl)) + } + } + + if len(answers) > 0 { + r.depthlogf(depth, "got answers for %q: %v", name, answers) + return answers, time.Duration(minTTL) * time.Second, nil + } + + r.depthlogf(depth, "no answers for %q", name) + + // If we have a non-zero number of CNAMEs, then try resolving those + // (from the root again) and return the first one that succeeds. + // + // TODO: return the union of all responses? + // TODO: parallelism? 
+ if len(cnames) > 0 { + r.depthlogf(depth, "got CNAME responses for %q: %v", name, cnames) + } + var cnameDepthError bool + for _, cname := range cnames { + answers, minTTL, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, cname, qtype) + if err == nil { + return answers, minTTL, nil + } else if errors.Is(err, ErrAuthoritativeNoResponses) { + return nil, 0, ErrAuthoritativeNoResponses + } else if errors.Is(err, ErrMaxDepth) { + cnameDepthError = true + } + } + + // If this is an authoritative response, then we know that continuing + // to look further is not going to result in any answers and we should + // bail out. + if resp.MsgHdr.Authoritative { + // If we failed to recurse into a CNAME due to a depth limit, + // propagate that here. + if cnameDepthError { + return nil, 0, ErrMaxDepth + } + + r.depthlogf(depth, "got authoritative response with no answers; stopping") + return nil, 0, ErrAuthoritativeNoResponses + } + + r.depthlogf(depth, "got %d NS responses and %d ADDITIONAL responses for %q", len(resp.Ns), len(resp.Extra), name) + + // No CNAMEs and no answers; see if we got any AUTHORITY responses, + // which indicate which nameservers to query next. + var authorities []dnsname.FQDN + for _, rr := range resp.Ns { + ns, ok := rr.(*dns.NS) + if !ok { + continue + } + + nsName, err := dnsname.ToFQDN(ns.Ns) + if err != nil { + r.logf("unexpected bad NS name %q: %v", ns.Ns, err) + continue + } + + authorities = append(authorities, nsName) + } + + // Also check for "glue" records, which are IP addresses provided by + // the DNS server for authority responses; these are required when the + // authority server is a subdomain of what's being resolved. 
+ glueRecords := make(map[dnsname.FQDN][]netip.Addr) + for _, rr := range resp.Extra { + name, err := dnsname.ToFQDN(rr.Header().Name) + if err != nil { + r.logf("unexpected bad Name %q in Extra addr: %v", rr.Header().Name, err) + continue + } + + if addr := addrFromRecord(rr); addr.IsValid() { + glueRecords[name] = append(glueRecords[name], addr) + } else { + r.logf("unexpected bad Extra %T addr", rr) + } + } + + // Try authorities with glue records first, to minimize the number of + // additional DNS queries that we need to make. + authoritiesGlue, authoritiesNoGlue := slicesx.Partition(authorities, func(aa dnsname.FQDN) bool { + return len(glueRecords[aa]) > 0 + }) + + authorityDepthError := false + + r.depthlogf(depth, "authorities with glue records for recursion: %v", authoritiesGlue) + for _, authority := range authoritiesGlue { + for _, nameserver := range glueRecords[authority] { + answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype) + if err == nil { + return answers, minTTL, nil + } else if errors.Is(err, ErrAuthoritativeNoResponses) { + return nil, 0, ErrAuthoritativeNoResponses + } else if errors.Is(err, ErrMaxDepth) { + authorityDepthError = true + } + } + } + + r.depthlogf(depth, "authorities with no glue records for recursion: %v", authoritiesNoGlue) + for _, authority := range authoritiesNoGlue { + // First, resolve the IP for the authority server from the + // root, querying for both IPv4 and IPv6 addresses regardless + // of what the current question type is. + // + // TODO: check for infinite recursion; it'll get caught by our + // recursion depth, but we want to bail early. 
+ for _, authorityQtype := range []dns.Type{qtypeAAAA, qtypeA} { + answers, _, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, authority, authorityQtype) + if err != nil { + r.depthlogf(depth, "error querying authority %q: %v", authority, err) + continue + } + r.depthlogf(depth, "resolved authority %q (type %v) to: %v", authority, authorityQtype, answers) + + // Now, query this authority for the final address. + for _, nameserver := range answers { + answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype) + if err == nil { + return answers, minTTL, nil + } else if errors.Is(err, ErrAuthoritativeNoResponses) { + return nil, 0, ErrAuthoritativeNoResponses + } else if errors.Is(err, ErrMaxDepth) { + authorityDepthError = true + } + } + } + } + + if authorityDepthError { + return nil, 0, ErrMaxDepth + } + return nil, 0, ErrNoResponses +} + +// queryNameserver sends a query for "name" to the nameserver "nameserver" for +// records of type "qtype", trying both UDP and TCP connections as +// appropriate. +func (r *Resolver) queryNameserver( + ctx context.Context, + depth int, + name dnsname.FQDN, // what we're querying + nameserver netip.Addr, // destination of query + qtype dns.Type, +) (*dns.Msg, error) { + // TODO(andrew): we should QNAME minimisation here to avoid sending the + // full name to intermediate/root nameservers. See: + // https://www.rfc-editor.org/rfc/rfc7816 + + // Handle the case where UDP is blocked by adding an explicit timeout + // for the UDP portion of this query. 
+ udpCtx, udpCtxCancel := context.WithTimeout(ctx, udpQueryTimeout) + defer udpCtxCancel() + + msg, err := r.queryNameserverProto(udpCtx, depth, name, nameserver, "udp", qtype) + if err == nil { + return msg, nil + } + + msg, err2 := r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype) + if err2 == nil { + return msg, nil + } + + return nil, multierr.New(err, err2) +} + +// queryNameserverProto sends a query for "name" to the nameserver "nameserver" +// for records of type "qtype" over the provided protocol (either "udp" +// or "tcp"), and returns the DNS response or an error. +func (r *Resolver) queryNameserverProto( + ctx context.Context, + depth int, + name dnsname.FQDN, // what we're querying + nameserver netip.Addr, // destination of query + protocol string, + qtype dns.Type, +) (resp *dns.Msg, err error) { + if r.testQueryHook != nil { + return r.testQueryHook(name, nameserver, protocol, qtype) + } + + now := r.now() + nameserverStr := nameserver.String() + + cacheKey := dnsQuery{ + nameserver: nameserver, + name: name, + qtype: qtype, + } + cacheEntry, ok := r.queryCache[cacheKey] + if ok && cacheEntry.expiresAt.Before(now) { + r.depthlogf(depth, "using cached response from %s about %q (type: %v)", nameserverStr, name, qtype) + return cacheEntry.Msg, nil + } + + var network string + if nameserver.Is4() { + network = protocol + "4" + } else { + network = protocol + "6" + } + + // Prepare a message asking for an appropriately-typed record + // for the name we're querying. + m := new(dns.Msg) + m.SetQuestion(name.WithTrailingDot(), uint16(qtype)) + + // Allow mocking out the network components with our exchange hook. + if r.testExchangeHook != nil { + resp, err = r.testExchangeHook(nameserver, network, m) + } else { + // Dial the current nameserver using our dialer. 
+ var nconn net.Conn + nconn, err = r.dialer().DialContext(ctx, network, net.JoinHostPort(nameserverStr, "53")) + if err != nil { + return nil, err + } + + var c dns.Client // TODO: share? + conn := &dns.Conn{ + Conn: nconn, + UDPSize: c.UDPSize, + } + + // Send the DNS request to the current nameserver. + r.depthlogf(depth, "asking %s over %s about %q (type: %v)", nameserverStr, protocol, name, qtype) + resp, _, err = c.ExchangeWithConnContext(ctx, m, conn) + } + if err != nil { + return nil, err + } + + // If the message was truncated and we're using UDP, re-run with TCP. + if resp.MsgHdr.Truncated && protocol == "udp" { + r.depthlogf(depth, "response message truncated; re-running query with TCP") + resp, err = r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype) + if err != nil { + return nil, err + } + } + + // Find minimum expiry for all records in this message. + var minTTL int + for _, rr := range resp.Answer { + minTTL = min(minTTL, int(rr.Header().Ttl)) + } + for _, rr := range resp.Ns { + minTTL = min(minTTL, int(rr.Header().Ttl)) + } + for _, rr := range resp.Extra { + minTTL = min(minTTL, int(rr.Header().Ttl)) + } + + mak.Set(&r.queryCache, cacheKey, dnsMsgWithExpiry{ + Msg: resp, + expiresAt: now.Add(time.Duration(minTTL) * time.Second), + }) + return resp, nil +} + +func addrFromRecord(rr dns.RR) netip.Addr { + switch v := rr.(type) { + case *dns.A: + ip, ok := netip.AddrFromSlice(v.A) + if !ok || !ip.Is4() { + return netip.Addr{} + } + return ip + case *dns.AAAA: + ip, ok := netip.AddrFromSlice(v.AAAA) + if !ok || !ip.Is6() { + return netip.Addr{} + } + return ip + } + return netip.Addr{} +} diff --git a/vendor/tailscale.com/net/dns/resolvconfpath_default.go b/vendor/tailscale.com/net/dns/resolvconfpath_default.go new file mode 100644 index 0000000000..57e82c4c77 --- /dev/null +++ b/vendor/tailscale.com/net/dns/resolvconfpath_default.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause 
+ +//go:build !gokrazy + +package dns + +const ( + resolvConf = "/etc/resolv.conf" + backupConf = "/etc/resolv.pre-tailscale-backup.conf" +) diff --git a/vendor/tailscale.com/net/dns/resolvconfpath_gokrazy.go b/vendor/tailscale.com/net/dns/resolvconfpath_gokrazy.go new file mode 100644 index 0000000000..f0759b0e31 --- /dev/null +++ b/vendor/tailscale.com/net/dns/resolvconfpath_gokrazy.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build gokrazy + +package dns + +const ( + resolvConf = "/tmp/resolv.conf" + backupConf = "/tmp/resolv.pre-tailscale-backup.conf" +) diff --git a/vendor/tailscale.com/net/dns/resolver/tsdns.go b/vendor/tailscale.com/net/dns/resolver/tsdns.go index 4ed0e57ba0..2b5a0869e3 100644 --- a/vendor/tailscale.com/net/dns/resolver/tsdns.go +++ b/vendor/tailscale.com/net/dns/resolver/tsdns.go @@ -724,7 +724,7 @@ func (r *Resolver) parseViaDomain(domain dnsname.FQDN, typ dns.Type) (netip.Addr return netip.Addr{}, false // badly formed, don't respond } - // MapVia will never error when given an ipv4 netip.Prefix. + // MapVia will never error when given an IPv4 netip.Prefix. 
out, _ := tsaddr.MapVia(uint32(prefix), netip.PrefixFrom(ip4, ip4.BitLen())) return out.Addr(), true } diff --git a/vendor/tailscale.com/net/dnsfallback/dnsfallback.go b/vendor/tailscale.com/net/dnsfallback/dnsfallback.go index 5e80c79c9f..de58fa38cc 100644 --- a/vendor/tailscale.com/net/dnsfallback/dnsfallback.go +++ b/vendor/tailscale.com/net/dnsfallback/dnsfallback.go @@ -19,25 +19,89 @@ import ( "net/url" "os" "reflect" + "slices" "sync/atomic" "time" + "go4.org/netipx" "tailscale.com/atomicfile" + "tailscale.com/envknob" + "tailscale.com/net/dns/recursive" "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" "tailscale.com/tailcfg" "tailscale.com/types/logger" + "tailscale.com/util/clientmetric" "tailscale.com/util/slicesx" ) +var disableRecursiveResolver = envknob.RegisterBool("TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER") + // MakeLookupFunc creates a function that can be used to resolve hostnames // (e.g. as a LookupIPFallback from dnscache.Resolver). // The netMon parameter is optional; if non-nil it's used to do faster interface lookups. func MakeLookupFunc(logf logger.Logf, netMon *netmon.Monitor) func(ctx context.Context, host string) ([]netip.Addr, error) { return func(ctx context.Context, host string) ([]netip.Addr, error) { - return lookup(ctx, host, logf, netMon) + if disableRecursiveResolver() { + return lookup(ctx, host, logf, netMon) + } + + addrsCh := make(chan []netip.Addr, 1) + + // Run the recursive resolver in the background so we can + // compare the results. + go func() { + logf := logger.WithPrefix(logf, "recursive: ") + + // Ensure that we catch panics while we're testing this + // code path; this should never panic, but we don't + // want to take down the process by having the panic + // propagate to the top of the goroutine's stack and + // then terminate. 
+ defer func() { + if r := recover(); r != nil { + logf("bootstrap DNS: recovered panic: %v", r) + metricRecursiveErrors.Add(1) + } + }() + + resolver := recursive.Resolver{ + Dialer: netns.NewDialer(logf, netMon), + Logf: logf, + } + addrs, minTTL, err := resolver.Resolve(ctx, host) + if err != nil { + logf("error using recursive resolver: %v", err) + metricRecursiveErrors.Add(1) + return + } + slices.SortFunc(addrs, netipx.CompareAddr) + + // Wait for a response from the main function + oldAddrs := <-addrsCh + slices.SortFunc(oldAddrs, netipx.CompareAddr) + + matches := slices.Equal(addrs, oldAddrs) + + logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL) + + if matches { + metricRecursiveMatches.Add(1) + } else { + metricRecursiveMismatches.Add(1) + } + }() + + addrs, err := lookup(ctx, host, logf, netMon) + if err != nil { + addrsCh <- nil + return nil, err + } + + addrsCh <- slices.Clone(addrs) + return addrs, nil } } @@ -254,3 +318,9 @@ func SetCachePath(path string, logf logger.Logf) { cachedDERPMap.Store(dm) logf("[v2] dnsfallback: SetCachePath loaded cached DERP map") } + +var ( + metricRecursiveMatches = clientmetric.NewCounter("dnsfallback_recursive_matches") + metricRecursiveMismatches = clientmetric.NewCounter("dnsfallback_recursive_mismatches") + metricRecursiveErrors = clientmetric.NewCounter("dnsfallback_recursive_errors") +) diff --git a/vendor/tailscale.com/net/interfaces/interfaces.go b/vendor/tailscale.com/net/interfaces/interfaces.go index e9e21eabca..6fe24c6512 100644 --- a/vendor/tailscale.com/net/interfaces/interfaces.go +++ b/vendor/tailscale.com/net/interfaces/interfaces.go @@ -11,6 +11,7 @@ import ( "net/http" "net/netip" "runtime" + "slices" "sort" "strings" @@ -24,49 +25,6 @@ import ( // which HTTP proxy the system should use. var LoginEndpointForProxyDetermination = "https://controlplane.tailscale.com/" -// Tailscale returns the current machine's Tailscale interface, if any. 
-// If none is found, all zero values are returned. -// A non-nil error is only returned on a problem listing the system interfaces. -func Tailscale() ([]netip.Addr, *net.Interface, error) { - ifs, err := netInterfaces() - if err != nil { - return nil, nil, err - } - for _, iface := range ifs { - if !maybeTailscaleInterfaceName(iface.Name) { - continue - } - addrs, err := iface.Addrs() - if err != nil { - continue - } - var tsIPs []netip.Addr - for _, a := range addrs { - if ipnet, ok := a.(*net.IPNet); ok { - nip, ok := netip.AddrFromSlice(ipnet.IP) - nip = nip.Unmap() - if ok && tsaddr.IsTailscaleIP(nip) { - tsIPs = append(tsIPs, nip) - } - } - } - if len(tsIPs) > 0 { - return tsIPs, iface.Interface, nil - } - } - return nil, nil, nil -} - -// maybeTailscaleInterfaceName reports whether s is an interface -// name that might be used by Tailscale. -func maybeTailscaleInterfaceName(s string) bool { - return s == "Tailscale" || - strings.HasPrefix(s, "wg") || - strings.HasPrefix(s, "ts") || - strings.HasPrefix(s, "tailscale") || - strings.HasPrefix(s, "utun") -} - func isUp(nif *net.Interface) bool { return nif.Flags&net.FlagUp != 0 } func isLoopback(nif *net.Interface) bool { return nif.Flags&net.FlagLoopback != 0 } @@ -300,9 +258,9 @@ func (s *State) String() string { } } sb.WriteString("ifs={") - ifs := make([]string, 0, len(s.Interface)) + var ifs []string for k := range s.Interface { - if anyInterestingIP(s.InterfaceIPs[k]) { + if s.keepInterfaceInStringSummary(k) { ifs = append(ifs, k) } } @@ -318,23 +276,40 @@ func (s *State) String() string { if i > 0 { sb.WriteString(" ") } - if s.Interface[ifName].IsUp() { - fmt.Fprintf(&sb, "%s:[", ifName) - needSpace := false - for _, pfx := range s.InterfaceIPs[ifName] { - if !isInterestingIP(pfx.Addr()) { - continue - } - if needSpace { - sb.WriteString(" ") - } + iface := s.Interface[ifName] + if iface.Interface == nil { + fmt.Fprintf(&sb, "%s:nil", ifName) + continue + } + if !iface.IsUp() { + fmt.Fprintf(&sb, 
"%s:down", ifName) + continue + } + fmt.Fprintf(&sb, "%s:[", ifName) + needSpace := false + for _, pfx := range s.InterfaceIPs[ifName] { + a := pfx.Addr() + if a.IsMulticast() { + continue + } + fam := "4" + if a.Is6() { + fam = "6" + } + if needSpace { + sb.WriteString(" ") + } + needSpace = true + switch { + case a.IsLoopback(): + fmt.Fprintf(&sb, "lo%s", fam) + case a.IsLinkLocalUnicast(): + fmt.Fprintf(&sb, "llu%s", fam) + default: fmt.Fprintf(&sb, "%s", pfx) - needSpace = true } - sb.WriteString("]") - } else { - fmt.Fprintf(&sb, "%s:down", ifName) } + sb.WriteString("]") } sb.WriteString("}") @@ -351,18 +326,8 @@ func (s *State) String() string { return sb.String() } -// An InterfaceFilter indicates whether EqualFiltered should use i when deciding whether two States are equal. -// ips are all the IPPrefixes associated with i. -type InterfaceFilter func(i Interface, ips []netip.Prefix) bool - -// An IPFilter indicates whether EqualFiltered should use ip when deciding whether two States are equal. -// ip is an ip address associated with some interface under consideration. -type IPFilter func(ip netip.Addr) bool - -// EqualFiltered reports whether s and s2 are equal, -// considering only interfaces in s for which filter returns true, -// and considering only IPs for those interfaces for which filterIP returns true. -func (s *State) EqualFiltered(s2 *State, useInterface InterfaceFilter, useIP IPFilter) bool { +// Equal reports whether s and s2 are exactly equal. 
+func (s *State) Equal(s2 *State) bool { if s == nil && s2 == nil { return true } @@ -378,19 +343,16 @@ func (s *State) EqualFiltered(s2 *State, useInterface InterfaceFilter, useIP IPF return false } for iname, i := range s.Interface { - ips := s.InterfaceIPs[iname] - if !useInterface(i, ips) { - continue - } i2, ok := s2.Interface[iname] if !ok { return false } - ips2, ok := s2.InterfaceIPs[iname] - if !ok { + if !i.Equal(i2) { return false } - if !interfacesEqual(i, i2) || !prefixesEqualFiltered(ips, ips2, useIP) { + } + for iname, vv := range s.InterfaceIPs { + if !slices.Equal(vv, s2.InterfaceIPs[iname]) { return false } } @@ -402,10 +364,9 @@ func (s *State) HasIP(ip netip.Addr) bool { if s == nil { return false } - want := netip.PrefixFrom(ip, ip.BitLen()) for _, pv := range s.InterfaceIPs { for _, p := range pv { - if p == want { + if p.Contains(ip) { return true } } @@ -413,70 +374,45 @@ func (s *State) HasIP(ip netip.Addr) bool { return false } -func interfacesEqual(a, b Interface) bool { - return a.Index == b.Index && +func (a Interface) Equal(b Interface) bool { + if (a.Interface == nil) != (b.Interface == nil) { + return false + } + if !(a.Desc == b.Desc && netAddrsEqual(a.AltAddrs, b.AltAddrs)) { + return false + } + if a.Interface != nil && !(a.Index == b.Index && a.MTU == b.MTU && a.Name == b.Name && a.Flags == b.Flags && - bytes.Equal([]byte(a.HardwareAddr), []byte(b.HardwareAddr)) -} - -func filteredIPPs(ipps []netip.Prefix, useIP IPFilter) []netip.Prefix { - // TODO: rewrite prefixesEqualFiltered to avoid making copies - x := make([]netip.Prefix, 0, len(ipps)) - for _, ipp := range ipps { - if useIP(ipp.Addr()) { - x = append(x, ipp) - } - } - return x -} - -func prefixesEqualFiltered(a, b []netip.Prefix, useIP IPFilter) bool { - return prefixesEqual(filteredIPPs(a, useIP), filteredIPPs(b, useIP)) -} - -func prefixesEqual(a, b []netip.Prefix) bool { - if len(a) != len(b) { + bytes.Equal([]byte(a.HardwareAddr), []byte(b.HardwareAddr))) { return 
false } - for i, v := range a { - if b[i] != v { - return false - } - } return true } -// UseInterestingInterfaces is an InterfaceFilter that reports whether i is an interesting interface. -// An interesting interface if it is (a) not owned by Tailscale and (b) routes interesting IP addresses. -// See UseInterestingIPs for the definition of an interesting IP address. -func UseInterestingInterfaces(i Interface, ips []netip.Prefix) bool { - return !isTailscaleInterface(i.Name, ips) && anyInterestingIP(ips) -} - -// UseInterestingIPs is an IPFilter that reports whether ip is an interesting IP address. -// An IP address is interesting if it is neither a loopback nor a link local unicast IP address. -func UseInterestingIPs(ip netip.Addr) bool { - return isInterestingIP(ip) -} - -// UseAllInterfaces is an InterfaceFilter that includes all interfaces. -func UseAllInterfaces(i Interface, ips []netip.Prefix) bool { return true } - -// UseAllIPs is an IPFilter that includes all IPs. -func UseAllIPs(ips netip.Addr) bool { return true } - func (s *State) HasPAC() bool { return s != nil && s.PAC != "" } // AnyInterfaceUp reports whether any interface seems like it has Internet access. func (s *State) AnyInterfaceUp() bool { - if runtime.GOOS == "js" { + if runtime.GOOS == "js" || runtime.GOOS == "tamago" { return true } return s != nil && (s.HaveV4 || s.HaveV6) } +func netAddrsEqual(a, b []net.Addr) bool { + if len(a) != len(b) { + return false + } + for i, av := range a { + if av.Network() != b[i].Network() || av.String() != b[i].String() { + return false + } + } + return true +} + func hasTailscaleIP(pfxs []netip.Prefix) bool { for _, pfx := range pfxs { if tsaddr.IsTailscaleIP(pfx.Addr()) { @@ -506,6 +442,8 @@ var getPAC func() string // GetState returns the state of all the current machine's network interfaces. // // It does not set the returned State.IsExpensive. The caller can populate that. +// +// Deprecated: use netmon.Monitor.InterfaceState instead. 
func GetState() (*State, error) { s := &State{ InterfaceIPs: make(map[string][]netip.Prefix), @@ -662,11 +600,23 @@ var ( v6Global1 = netip.MustParsePrefix("2000::/3") ) -// anyInterestingIP reports whether pfxs contains any IP that matches -// isInterestingIP. -func anyInterestingIP(pfxs []netip.Prefix) bool { - for _, pfx := range pfxs { - if isInterestingIP(pfx.Addr()) { +// keepInterfaceInStringSummary reports whether the named interface should be included +// in the String method's summary string. +func (s *State) keepInterfaceInStringSummary(ifName string) bool { + iface, ok := s.Interface[ifName] + if !ok || iface.Interface == nil { + return false + } + if ifName == s.DefaultRouteInterface { + return true + } + up := iface.IsUp() + for _, p := range s.InterfaceIPs[ifName] { + a := p.Addr() + if a.IsLinkLocalUnicast() || a.IsLoopback() { + continue + } + if up || a.IsGlobalUnicast() || a.IsPrivate() { return true } } @@ -675,9 +625,12 @@ func anyInterestingIP(pfxs []netip.Prefix) bool { // isInterestingIP reports whether ip is an interesting IP that we // should log in interfaces.State logging. We don't need to show -// localhost or link-local addresses. +// loopback, link-local addresses, or non-Tailscale ULA addresses. func isInterestingIP(ip netip.Addr) bool { - return !ip.IsLoopback() && !ip.IsLinkLocalUnicast() + if ip.IsLoopback() || ip.IsLinkLocalUnicast() { + return false + } + return true } var altNetInterfaces func() ([]Interface, error) diff --git a/vendor/tailscale.com/net/memnet/conn.go b/vendor/tailscale.com/net/memnet/conn.go index fb7776e613..a9e1fd3990 100644 --- a/vendor/tailscale.com/net/memnet/conn.go +++ b/vendor/tailscale.com/net/memnet/conn.go @@ -9,6 +9,10 @@ import ( "time" ) +// NetworkName is the network name returned by [net.Addr.Network] +// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type. 
+const NetworkName = "mem" + // Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. type Conn interface { net.Conn @@ -45,7 +49,7 @@ func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { type connAddr string -func (a connAddr) Network() string { return "mem" } +func (a connAddr) Network() string { return NetworkName } func (a connAddr) String() string { return string(a) } type connHalf struct { diff --git a/vendor/tailscale.com/net/netcheck/netcheck.go b/vendor/tailscale.com/net/netcheck/netcheck.go index b0a93454a6..a863a5a19f 100644 --- a/vendor/tailscale.com/net/netcheck/netcheck.go +++ b/vendor/tailscale.com/net/netcheck/netcheck.go @@ -27,7 +27,6 @@ import ( "tailscale.com/envknob" "tailscale.com/net/dnscache" "tailscale.com/net/interfaces" - "tailscale.com/net/netaddr" "tailscale.com/net/neterror" "tailscale.com/net/netmon" "tailscale.com/net/netns" @@ -40,7 +39,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/nettype" "tailscale.com/types/opt" - "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/cmpx" "tailscale.com/util/mak" @@ -82,6 +81,7 @@ const ( defaultInitialRetransmitTime = 100 * time.Millisecond ) +// Report contains the result of a single netcheck. type Report struct { UDP bool // a UDP STUN round trip completed IPv6 bool // an IPv6 STUN round trip completed @@ -152,7 +152,12 @@ func cloneDurationMap(m map[int]time.Duration) map[int]time.Duration { return m2 } -// Client generates a netcheck Report. +// Client generates Reports describing the result of both passive and active +// network configuration probing. It provides two different modes of report, a +// full report (see MakeNextReportFull) and a more lightweight incremental +// report. The client must be provided with SendPacket in order to perform +// active probes, and must receive STUN packet replies via ReceiveSTUNPacket. 
+// Client can be used in a standalone fashion via the Standalone method. type Client struct { // Verbose enables verbose logging. Verbose bool @@ -164,28 +169,22 @@ type Client struct { // NetMon optionally provides a netmon.Monitor to use to get the current // (cached) network interface. // If nil, the interface will be looked up dynamically. + // TODO(bradfitz): make NetMon required. As of 2023-08-01, it basically always is + // present anyway. NetMon *netmon.Monitor // TimeNow, if non-nil, is used instead of time.Now. TimeNow func() time.Time - // GetSTUNConn4 optionally provides a func to return the - // connection to use for sending & receiving IPv4 packets. If - // nil, an ephemeral one is created as needed. - GetSTUNConn4 func() STUNConn - - // GetSTUNConn6 is like GetSTUNConn4, but for IPv6. - GetSTUNConn6 func() STUNConn + // SendPacket is required to send a packet to the specified address. For + // convenience it shares a signature with WriteToUDPAddrPort. + SendPacket func([]byte, netip.AddrPort) (int, error) // SkipExternalNetwork controls whether the client should not try // to reach things other than localhost. This is set to true // in tests to avoid probing the local LAN's router, etc. SkipExternalNetwork bool - // UDPBindAddr, if non-empty, is the address to listen on for UDP. - // It defaults to ":0". - UDPBindAddr string - // PortMapper, if non-nil, is used for portmap queries. // If nil, portmap discovery is not done. 
PortMapper *portmapper.Client // lazily initialized on first use @@ -208,17 +207,10 @@ type Client struct { prev map[time.Time]*Report // some previous reports last *Report // most recent report lastFull time.Time // time of last full (non-incremental) report - curState *reportState // non-nil if we're in a call to GetReportn + curState *reportState // non-nil if we're in a call to GetReport resolver *dnscache.Resolver // only set if UseDNSCache is true } -// STUNConn is the interface required by the netcheck Client when -// reusing an existing UDP connection. -type STUNConn interface { - WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) - ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) -} - func (c *Client) enoughRegions() int { if c.testEnoughRegions > 0 { return c.testEnoughRegions @@ -278,6 +270,10 @@ func (c *Client) MakeNextReportFull() { c.nextFull = true } +// ReceiveSTUNPacket must be called when a STUN packet is received as a reply to +// packet the client sent using SendPacket. In Standalone this is performed by +// the loop started by Standalone, in normal operation in tailscaled incoming +// STUN replies are routed to this method. func (c *Client) ReceiveSTUNPacket(pkt []byte, src netip.AddrPort) { c.vlogf("received STUN packet from %s", src) @@ -524,53 +520,12 @@ func nodeMight4(n *tailcfg.DERPNode) bool { return ip.Is4() } -type packetReaderFromCloser interface { - ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) - io.Closer -} - -// readPackets reads STUN packets from pc until there's an error or ctx is done. -// In either case, it closes pc. 
-func (c *Client) readPackets(ctx context.Context, pc packetReaderFromCloser) { - done := make(chan struct{}) - defer close(done) - - go func() { - select { - case <-ctx.Done(): - case <-done: - } - pc.Close() - }() - - var buf [64 << 10]byte - for { - n, addr, err := pc.ReadFromUDPAddrPort(buf[:]) - if err != nil { - if ctx.Err() != nil { - return - } - c.logf("ReadFrom: %v", err) - return - } - pkt := buf[:n] - if !stun.Is(pkt) { - continue - } - if ap := netaddr.Unmap(addr); ap.IsValid() { - c.ReceiveSTUNPacket(pkt, ap) - } - } -} - // reportState holds the state for a single invocation of Client.GetReport. type reportState struct { c *Client hairTX stun.TxID gotHairSTUN chan netip.AddrPort hairTimeout chan struct{} // closed on timeout - pc4 STUNConn - pc6 STUNConn pc4Hair nettype.PacketConn incremental bool // doing a lite, follow-up netcheck stopProbeCh chan struct{} @@ -781,13 +736,6 @@ func newReport() *Report { } } -func (c *Client) udpBindAddr() string { - if v := c.UDPBindAddr; v != "" { - return v - } - return ":0" -} - // GetReport gets a report. // // It may not be called concurrently with itself. 
@@ -920,42 +868,6 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (_ *Report, []byte("tailscale netcheck; see https://github.com/tailscale/tailscale/issues/188"), netip.AddrPortFrom(netip.MustParseAddr(documentationIP), 12345)) - if f := c.GetSTUNConn4; f != nil { - rs.pc4 = f() - } else { - u4, err := nettype.MakePacketListenerWithNetIP(netns.Listener(c.logf, nil)).ListenPacket(ctx, "udp4", c.udpBindAddr()) - if err != nil { - c.logf("udp4: %v", err) - return nil, err - } - rs.pc4 = u4 - go c.readPackets(ctx, u4) - } - - if ifState.HaveV6 { - if f := c.GetSTUNConn6; f != nil { - rs.pc6 = f() - } else { - u6, err := nettype.MakePacketListenerWithNetIP(netns.Listener(c.logf, nil)).ListenPacket(ctx, "udp6", c.udpBindAddr()) - if err != nil { - c.logf("udp6: %v", err) - } else { - rs.pc6 = u6 - go c.readPackets(ctx, u6) - } - } - - // If our interfaces.State suggested we have IPv6 support but then we - // failed to get an IPv6 sending socket (as in - // https://github.com/tailscale/tailscale/issues/7949), then change - // ifState.HaveV6 before we make a probe plan that involves sending IPv6 - // packets and thus assuming rs.pc6 is non-nil. - if rs.pc6 == nil { - ifState = ptr.To(*ifState) // shallow clone - ifState.HaveV6 = false - } - } - plan := makeProbePlan(dm, ifState, last) // If we're doing a full probe, also check for a captive portal. We @@ -1110,7 +1022,7 @@ func (c *Client) finishAndStoreReport(rs *reportState, dm *tailcfg.DERPMap) *Rep report := rs.report.Clone() rs.mu.Unlock() - c.addReportHistoryAndSetPreferredDERP(report) + c.addReportHistoryAndSetPreferredDERP(report, dm.View()) c.logConciseReport(report, dm) return report @@ -1442,9 +1354,20 @@ func (c *Client) timeNow() time.Time { return time.Now() } +const ( + // preferredDERPAbsoluteDiff specifies the minimum absolute difference + // in latencies between two DERP regions that would cause a node to + // switch its PreferredDERP ("home DERP"). 
This ensures that if a node + // is 5ms from two different DERP regions, it doesn't flip-flop back + // and forth between them if one region gets slightly slower (e.g. if a + // node is near region 1 @ 4ms and region 2 @ 5ms, region 1 getting + // 5ms slower would cause a flap). + preferredDERPAbsoluteDiff = 10 * time.Millisecond +) + // addReportHistoryAndSetPreferredDERP adds r to the set of recent Reports // and mutates r.PreferredDERP to contain the best recent one. -func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report) { +func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report, dm tailcfg.DERPMapView) { c.mu.Lock() defer c.mu.Unlock() @@ -1476,11 +1399,33 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report) { } } + // Scale each region's best latency by any provided scores from the + // DERPMap, for use in comparison below. + var scores views.Map[int, float64] + if hp := dm.HomeParams(); hp.Valid() { + scores = hp.RegionScore() + } + for regionID, d := range bestRecent { + if score := scores.Get(regionID); score > 0 { + bestRecent[regionID] = time.Duration(float64(d) * score) + } + } + // Then, pick which currently-alive DERP server from the // current report has the best latency over the past maxAge. - var bestAny time.Duration - var oldRegionCurLatency time.Duration + var ( + bestAny time.Duration // global minimum + oldRegionCurLatency time.Duration // latency of old PreferredDERP + ) for regionID, d := range r.RegionLatency { + // Scale this report's latency by any scores provided by the + // server; we did this for the bestRecent map above, but we + // don't mutate the actual reports in-place (in case scores + // change), so we need to do it here as well. 
+ if score := scores.Get(regionID); score > 0 { + d = time.Duration(float64(d) * score) + } + if regionID == prevDERP { oldRegionCurLatency = d } @@ -1491,13 +1436,27 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report) { } } - // If we're changing our preferred DERP but the old one's still - // accessible and the new one's not much better, just stick with - // where we are. - if prevDERP != 0 && - r.PreferredDERP != prevDERP && - oldRegionCurLatency != 0 && - bestAny > oldRegionCurLatency/3*2 { + // If we're changing our preferred DERP, we want to add some stickiness + // to the current DERP region. We avoid changing if the old region is + // still accessible and one of the conditions below is true. + keepOld := false + changingPreferred := prevDERP != 0 && r.PreferredDERP != prevDERP + oldRegionIsAccessible := oldRegionCurLatency != 0 + if changingPreferred && oldRegionIsAccessible { + // bestAny < any other value, so oldRegionCurLatency - bestAny >= 0 + if oldRegionCurLatency-bestAny < preferredDERPAbsoluteDiff { + // The absolute value of latency difference is below + // our minimum threshold. + keepOld = true + } + if bestAny > oldRegionCurLatency/3*2 { + // Old region is about the same on a percentage basis + keepOld = true + } + } + if keepOld { + // Reset the report's PreferredDERP to be the previous value, + // which undoes any region change we made above. 
r.PreferredDERP = prevDERP } } @@ -1563,26 +1522,35 @@ func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe } rs.mu.Unlock() + if rs.c.SendPacket == nil { + rs.mu.Lock() + rs.report.IPv4CanSend = false + rs.report.IPv6CanSend = false + rs.mu.Unlock() + return + } + switch probe.proto { case probeIPv4: metricSTUNSend4.Add(1) - n, err := rs.pc4.WriteToUDPAddrPort(req, addr) - if n == len(req) && err == nil || neterror.TreatAsLostUDP(err) { - rs.mu.Lock() - rs.report.IPv4CanSend = true - rs.mu.Unlock() - } case probeIPv6: metricSTUNSend6.Add(1) - n, err := rs.pc6.WriteToUDPAddrPort(req, addr) - if n == len(req) && err == nil || neterror.TreatAsLostUDP(err) { - rs.mu.Lock() - rs.report.IPv6CanSend = true - rs.mu.Unlock() - } default: panic("bad probe proto " + fmt.Sprint(probe.proto)) } + + n, err := rs.c.SendPacket(req, addr) + if n == len(req) && err == nil || neterror.TreatAsLostUDP(err) { + rs.mu.Lock() + switch probe.proto { + case probeIPv4: + rs.report.IPv4CanSend = true + case probeIPv6: + rs.report.IPv6CanSend = true + } + rs.mu.Unlock() + } + c.vlogf("sent to %v", addr) } diff --git a/vendor/tailscale.com/net/netcheck/standalone.go b/vendor/tailscale.com/net/netcheck/standalone.go new file mode 100644 index 0000000000..87fbc211ea --- /dev/null +++ b/vendor/tailscale.com/net/netcheck/standalone.go @@ -0,0 +1,99 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netcheck + +import ( + "context" + "errors" + "net/netip" + + "tailscale.com/net/netaddr" + "tailscale.com/net/netns" + "tailscale.com/net/stun" + "tailscale.com/types/logger" + "tailscale.com/types/nettype" + "tailscale.com/util/multierr" +) + +// Standalone creates the necessary UDP sockets on the given bindAddr and starts +// an IO loop so that the Client can perform active probes with no further need +// for external driving of IO (no need to set/implement SendPacket, or call +// ReceiveSTUNPacket). 
It must be called prior to starting any reports and is +// shut down by cancellation of the provided context. If both IPv4 and IPv6 fail +// to bind, errors will be returned, if one or both protocols can bind no error +// is returned. +func (c *Client) Standalone(ctx context.Context, bindAddr string) error { + if bindAddr == "" { + bindAddr = ":0" + } + var errs []error + + u4, err := nettype.MakePacketListenerWithNetIP(netns.Listener(c.logf, nil)).ListenPacket(ctx, "udp4", bindAddr) + if err != nil { + c.logf("udp4: %v", err) + errs = append(errs, err) + } else { + go readPackets(ctx, c.logf, u4, c.ReceiveSTUNPacket) + } + + u6, err := nettype.MakePacketListenerWithNetIP(netns.Listener(c.logf, nil)).ListenPacket(ctx, "udp6", bindAddr) + if err != nil { + c.logf("udp6: %v", err) + errs = append(errs, err) + } else { + go readPackets(ctx, c.logf, u6, c.ReceiveSTUNPacket) + } + + c.SendPacket = func(pkt []byte, dst netip.AddrPort) (int, error) { + pc := u4 + if dst.Addr().Is6() { + pc = u6 + } + if pc == nil { + return 0, errors.New("no UDP socket") + } + + return pc.WriteToUDPAddrPort(pkt, dst) + } + + // If both v4 and v6 failed, report an error, otherwise let one succeed. + if len(errs) == 2 { + return multierr.New(errs...) + } + return nil +} + +// readPackets reads STUN packets from pc until there's an error or ctx is done. +// In either case, it closes pc. 
+func readPackets(ctx context.Context, logf logger.Logf, pc nettype.PacketConn, recv func([]byte, netip.AddrPort)) { + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + case <-done: + } + pc.Close() + }() + + var buf [64 << 10]byte + for { + n, addr, err := pc.ReadFromUDPAddrPort(buf[:]) + if err != nil { + if ctx.Err() != nil { + return + } + logf("ReadFrom: %v", err) + return + } + pkt := buf[:n] + if !stun.Is(pkt) { + continue + } + if ap := netaddr.Unmap(addr); ap.IsValid() { + recv(pkt, ap) + } + } +} diff --git a/vendor/tailscale.com/net/netmon/netmon.go b/vendor/tailscale.com/net/netmon/netmon.go index 4de8a47f47..b0872034ad 100644 --- a/vendor/tailscale.com/net/netmon/netmon.go +++ b/vendor/tailscale.com/net/netmon/netmon.go @@ -16,6 +16,7 @@ import ( "tailscale.com/net/interfaces" "tailscale.com/types/logger" + "tailscale.com/util/clientmetric" "tailscale.com/util/set" ) @@ -51,10 +52,14 @@ type osMon interface { // Monitor represents a monitoring instance. type Monitor struct { logf logger.Logf - om osMon // nil means not supported on this platform - change chan struct{} + om osMon // nil means not supported on this platform + change chan bool // send false to wake poller, true to also force ChangeDeltas be sent stop chan struct{} // closed on Stop + // Things that must be set early, before use, + // and not change at runtime. + tsIfName string // tailscale interface name, if known/set ("tailscale0", "utun3", ...) + mu sync.Mutex // guards all following fields cbs set.HandleSet[ChangeFunc] ruleDelCB set.HandleSet[RuleDeleteCallback] @@ -71,9 +76,40 @@ type Monitor struct { } // ChangeFunc is a callback function registered with Monitor that's called when the -// network changed. The changed parameter is whether the network changed -// enough for State to have changed since the last callback. -type ChangeFunc func(changed bool, state *interfaces.State) +// network changed. 
+type ChangeFunc func(*ChangeDelta) + +// ChangeDelta describes the difference between two network states. +type ChangeDelta struct { + // Monitor is the network monitor that sent this delta. + Monitor *Monitor + + // Old is the old interface state, if known. + // It's nil if the old state is unknown. + // Do not mutate it. + Old *interfaces.State + + // New is the new network state. + // It is always non-nil. + // Do not mutate it. + New *interfaces.State + + // Major is our legacy boolean of whether the network changed in some major + // way. + // + // Deprecated: do not remove. As of 2023-08-23 we're in a renewed effort to + // remove it and ask specific qustions of ChangeDelta instead. Look at Old + // and New (or add methods to ChangeDelta) instead of using Major. + Major bool + + // TimeJumped is whether there was a big jump in wall time since the last + // time we checked. This is a hint that a mobile sleeping device might have + // come out of sleep. + TimeJumped bool + + // TODO(bradfitz): add some lazy cached fields here as needed with methods + // on *ChangeDelta to let callers ask specific questions +} // New instantiates and starts a monitoring instance. // The returned monitor is inactive until it's started by the Start method. @@ -82,7 +118,7 @@ func New(logf logger.Logf) (*Monitor, error) { logf = logger.WithPrefix(logf, "monitor: ") m := &Monitor{ logf: logf, - change: make(chan struct{}, 1), + change: make(chan bool, 1), stop: make(chan struct{}), lastWall: wallTime(), } @@ -117,6 +153,15 @@ func (m *Monitor) interfaceStateUncached() (*interfaces.State, error) { return interfaces.GetState() } +// SetTailscaleInterfaceName sets the name of the Tailscale interface. For +// example, "tailscale0", "tun0", "utun3", etc. +// +// This must be called only early in tailscaled startup before the monitor is +// used. 
+func (m *Monitor) SetTailscaleInterfaceName(ifName string) { + m.tsIfName = ifName +} + // GatewayAndSelfIP returns the current network's default gateway, and // the machine's default IP for that gateway. // @@ -129,8 +174,14 @@ func (m *Monitor) GatewayAndSelfIP() (gw, myIP netip.Addr, ok bool) { return m.gw, m.gwSelfIP, true } gw, myIP, ok = interfaces.LikelyHomeRouterIP() + changed := false if ok { - m.gw, m.gwSelfIP, m.gwValid = gw, myIP, true + changed = m.gw != gw || m.gwSelfIP != myIP + m.gw, m.gwSelfIP = gw, myIP + m.gwValid = true + } + if changed { + m.logf("gateway and self IP changed: gw=%v self=%v", m.gw, m.gwSelfIP) } return gw, myIP, ok } @@ -225,7 +276,7 @@ func (m *Monitor) Close() error { // period (under a fraction of a second). func (m *Monitor) InjectEvent() { select { - case m.change <- struct{}{}: + case m.change <- true: default: // Another change signal is already // buffered. Debounce will wake up soon @@ -233,6 +284,18 @@ func (m *Monitor) InjectEvent() { } } +// Poll forces the monitor to pretend there was a network +// change and re-check the state of the network. +// +// This is like InjectEvent but only fires ChangeFunc callbacks +// if the network state differed at all. +func (m *Monitor) Poll() { + select { + case m.change <- false: + default: + } +} + func (m *Monitor) stopped() bool { select { case <-m.stop: @@ -264,7 +327,7 @@ func (m *Monitor) pump() { if msg.ignore() { continue } - m.InjectEvent() + m.Poll() } } @@ -280,7 +343,11 @@ func (m *Monitor) notifyRuleDeleted(rdm ipRuleDeletedMessage) { // considered when checking for network state changes. // The ips parameter should be the IPs of the provided interface. 
func (m *Monitor) isInterestingInterface(i interfaces.Interface, ips []netip.Prefix) bool { - return m.om.IsInterestingInterface(i.Name) && interfaces.UseInterestingInterfaces(i, ips) + if !m.om.IsInterestingInterface(i.Name) { + return false + } + + return true } // debounce calls the callback function with a delay between events @@ -288,42 +355,17 @@ func (m *Monitor) isInterestingInterface(i interfaces.Interface, ips []netip.Pre func (m *Monitor) debounce() { defer m.goroutines.Done() for { + var forceCallbacks bool select { case <-m.stop: return - case <-m.change: + case forceCallbacks = <-m.change: } - if curState, err := m.interfaceStateUncached(); err != nil { + if newState, err := m.interfaceStateUncached(); err != nil { m.logf("interfaces.State: %v", err) } else { - m.mu.Lock() - - oldState := m.ifState - changed := !curState.EqualFiltered(oldState, m.isInterestingInterface, interfaces.UseInterestingIPs) - if changed { - m.gwValid = false - m.ifState = curState - - if s1, s2 := oldState.String(), curState.String(); s1 == s2 { - m.logf("[unexpected] network state changed, but stringification didn't: %v", s1) - m.logf("[unexpected] old: %s", jsonSummary(oldState)) - m.logf("[unexpected] new: %s", jsonSummary(curState)) - } - } - // See if we have a queued or new time jump signal. - if shouldMonitorTimeJump && m.checkWallTimeAdvanceLocked() { - m.resetTimeJumpedLocked() - if !changed { - // Only log if it wasn't an interesting change. 
- m.logf("time jumped (probably wake from sleep); synthesizing major change event") - changed = true - } - } - for _, cb := range m.cbs { - go cb(changed, m.ifState) - } - m.mu.Unlock() + m.handlePotentialChange(newState, forceCallbacks) } select { @@ -334,6 +376,140 @@ func (m *Monitor) debounce() { } } +var ( + metricChangeEq = clientmetric.NewCounter("netmon_link_change_eq") + metricChange = clientmetric.NewCounter("netmon_link_change") + metricChangeTimeJump = clientmetric.NewCounter("netmon_link_change_timejump") + metricChangeMajor = clientmetric.NewCounter("netmon_link_change_major") +) + +// handlePotentialChange considers whether newState is different enough to wake +// up callers and updates the monitor's state if so. +// +// If forceCallbacks is true, they're always notified. +func (m *Monitor) handlePotentialChange(newState *interfaces.State, forceCallbacks bool) { + m.mu.Lock() + defer m.mu.Unlock() + oldState := m.ifState + timeJumped := shouldMonitorTimeJump && m.checkWallTimeAdvanceLocked() + if !timeJumped && !forceCallbacks && oldState.Equal(newState) { + // Exactly equal. Nothing to do. + metricChangeEq.Add(1) + return + } + + delta := &ChangeDelta{ + Monitor: m, + Old: oldState, + New: newState, + TimeJumped: timeJumped, + } + + delta.Major = m.IsMajorChangeFrom(oldState, newState) + if delta.Major { + m.gwValid = false + m.ifState = newState + + if s1, s2 := oldState.String(), delta.New.String(); s1 == s2 { + m.logf("[unexpected] network state changed, but stringification didn't: %v", s1) + m.logf("[unexpected] old: %s", jsonSummary(oldState)) + m.logf("[unexpected] new: %s", jsonSummary(newState)) + } + } + // See if we have a queued or new time jump signal. + if timeJumped { + m.resetTimeJumpedLocked() + if !delta.Major { + // Only log if it wasn't an interesting change. 
+ m.logf("time jumped (probably wake from sleep); synthesizing major change event") + delta.Major = true + } + } + metricChange.Add(1) + if delta.Major { + metricChangeMajor.Add(1) + } + if delta.TimeJumped { + metricChangeTimeJump.Add(1) + } + for _, cb := range m.cbs { + go cb(delta) + } +} + +// IsMajorChangeFrom reports whether the transition from s1 to s2 is +// a "major" change, where major roughly means it's worth tearing down +// a bunch of connections and rebinding. +// +// TODO(bradiftz): tigten this definition. +func (m *Monitor) IsMajorChangeFrom(s1, s2 *interfaces.State) bool { + if s1 == nil && s2 == nil { + return false + } + if s1 == nil || s2 == nil { + return true + } + if s1.HaveV6 != s2.HaveV6 || + s1.HaveV4 != s2.HaveV4 || + s1.IsExpensive != s2.IsExpensive || + s1.DefaultRouteInterface != s2.DefaultRouteInterface || + s1.HTTPProxy != s2.HTTPProxy || + s1.PAC != s2.PAC { + return true + } + for iname, i := range s1.Interface { + if iname == m.tsIfName { + // Ignore changes in the Tailscale interface itself. + continue + } + ips := s1.InterfaceIPs[iname] + if !m.isInterestingInterface(i, ips) { + continue + } + i2, ok := s2.Interface[iname] + if !ok { + return true + } + ips2, ok := s2.InterfaceIPs[iname] + if !ok { + return true + } + if !i.Equal(i2) || !prefixesMajorEqual(ips, ips2) { + return true + } + } + return false +} + +// prefixesMajorEqual reports whether a and b are equal after ignoring +// boring things like link-local, loopback, and multicast addresses. +func prefixesMajorEqual(a, b []netip.Prefix) bool { + // trim returns a subslice of p with link local unicast, + // loopback, and multicast prefixes removed from the front. 
+ trim := func(p []netip.Prefix) []netip.Prefix { + for len(p) > 0 { + a := p[0].Addr() + if a.IsLinkLocalUnicast() || a.IsLoopback() || a.IsMulticast() { + p = p[1:] + continue + } + break + } + return p + } + for { + a = trim(a) + b = trim(b) + if len(a) == 0 || len(b) == 0 { + return len(a) == 0 && len(b) == 0 + } + if a[0] != b[0] { + return false + } + a, b = a[1:], b[1:] + } +} + func jsonSummary(x any) any { j, err := json.Marshal(x) if err != nil { diff --git a/vendor/tailscale.com/net/netmon/netmon_linux.go b/vendor/tailscale.com/net/netmon/netmon_linux.go index 9065b99532..dd23dd3426 100644 --- a/vendor/tailscale.com/net/netmon/netmon_linux.go +++ b/vendor/tailscale.com/net/netmon/netmon_linux.go @@ -100,7 +100,7 @@ func (c *nlConn) Receive() (message, error) { typ = "RTM_DELADDR" } - // label attributes are seemingly only populated for ipv4 addresses in the wild. + // label attributes are seemingly only populated for IPv4 addresses in the wild. label := rmsg.Attributes.Label if label == "" { itf, err := net.InterfaceByIndex(int(rmsg.Index)) diff --git a/vendor/tailscale.com/net/netmon/polling.go b/vendor/tailscale.com/net/netmon/polling.go index 9332bdde98..ce1618ed6c 100644 --- a/vendor/tailscale.com/net/netmon/polling.go +++ b/vendor/tailscale.com/net/netmon/polling.go @@ -13,7 +13,6 @@ import ( "sync" "time" - "tailscale.com/net/interfaces" "tailscale.com/types/logger" ) @@ -68,22 +67,18 @@ func (pm *pollingMon) Receive() (message, error) { // so this can go very slowly there, to save battery. // https://github.com/tailscale/tailscale/issues/1427 d = 10 * time.Minute - } - if pm.isCloudRun() { + } else if pm.isCloudRun() { // Cloud Run routes never change at runtime. the containers are killed within // 15 minutes by default, set the interval long enough to be effectively infinite. 
pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h") d = 24 * time.Hour } - ticker := time.NewTicker(d) - defer ticker.Stop() - base := pm.m.InterfaceState() + timer := time.NewTimer(d) + defer timer.Stop() for { - if cur, err := pm.m.interfaceStateUncached(); err == nil && !cur.EqualFiltered(base, interfaces.UseInterestingInterfaces, interfaces.UseInterestingIPs) { - return unspecifiedMessage{}, nil - } select { - case <-ticker.C: + case <-timer.C: + return unspecifiedMessage{}, nil case <-pm.stop: return nil, errors.New("stopped") } diff --git a/vendor/tailscale.com/net/netns/netns_darwin.go b/vendor/tailscale.com/net/netns/netns_darwin.go index b32a744b7e..b9395f7347 100644 --- a/vendor/tailscale.com/net/netns/netns_darwin.go +++ b/vendor/tailscale.com/net/netns/netns_darwin.go @@ -20,6 +20,7 @@ import ( "tailscale.com/envknob" "tailscale.com/net/interfaces" "tailscale.com/net/netmon" + "tailscale.com/net/tsaddr" "tailscale.com/types/logger" ) @@ -110,7 +111,7 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) // Verify that we didn't just choose the Tailscale interface; // if so, we fall back to binding from the default. - _, tsif, err2 := interfaces.Tailscale() + tsif, err2 := tailscaleInterface() if err2 == nil && tsif != nil && tsif.Index == idx { logf("[unexpected] netns: interfaceIndexFor returned Tailscale interface") return defaultIdx() @@ -119,6 +120,34 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) return idx, err } +// tailscaleInterface returns the current machine's Tailscale interface, if any. +// If none is found, (nil, nil) is returned. +// A non-nil error is only returned on a problem listing the system interfaces. 
+func tailscaleInterface() (*net.Interface, error) { + ifs, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, iface := range ifs { + if !strings.HasPrefix(iface.Name, "utun") { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + for _, a := range addrs { + if ipnet, ok := a.(*net.IPNet); ok { + nip, ok := netip.AddrFromSlice(ipnet.IP) + if ok && tsaddr.IsTailscaleIP(nip.Unmap()) { + return &iface, nil + } + } + } + } + return nil, nil +} + // interfaceIndexFor returns the interface index that we should bind to in // order to send traffic to the provided address. func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) { diff --git a/vendor/tailscale.com/net/netns/netns_linux.go b/vendor/tailscale.com/net/netns/netns_linux.go index 5d09d7d192..bac14e9d77 100644 --- a/vendor/tailscale.com/net/netns/netns_linux.go +++ b/vendor/tailscale.com/net/netns/netns_linux.go @@ -17,16 +17,9 @@ import ( "tailscale.com/net/interfaces" "tailscale.com/net/netmon" "tailscale.com/types/logger" + "tailscale.com/util/linuxfw" ) -// tailscaleBypassMark is the mark indicating that packets originating -// from a socket should bypass Tailscale-managed routes during routing -// table lookups. -// -// Keep this in sync with tailscaleBypassMark in -// wgengine/router/router_linux.go. -const tailscaleBypassMark = 0x80000 - // socketMarkWorksOnce is the sync.Once & cached value for useSocketMark. 
var socketMarkWorksOnce struct { sync.Once @@ -119,7 +112,7 @@ func controlC(network, address string, c syscall.RawConn) error { } func setBypassMark(fd uintptr) error { - if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, tailscaleBypassMark); err != nil { + if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, linuxfw.TailscaleBypassMarkNum); err != nil { return fmt.Errorf("setting SO_MARK bypass: %w", err) } return nil diff --git a/vendor/tailscale.com/net/netutil/ip_forward.go b/vendor/tailscale.com/net/netutil/ip_forward.go index f8c6a90fcd..afcea4e5ad 100644 --- a/vendor/tailscale.com/net/netutil/ip_forward.go +++ b/vendor/tailscale.com/net/netutil/ip_forward.go @@ -51,7 +51,7 @@ func protocolsRequiredForForwarding(routes []netip.Prefix, state *interfaces.Sta // CheckIPForwarding reports whether IP forwarding is enabled correctly // for subnet routing and exit node functionality on any interface. -// The state param can be nil, in which case interfaces.GetState is used. +// The state param must not be nil. // The routes should only be advertised routes, and should not contain the // nodes Tailscale IPs. // It returns an error if it is unable to determine if IP forwarding is enabled. 
@@ -65,14 +65,10 @@ func CheckIPForwarding(routes []netip.Prefix, state *interfaces.State) (warn, er } return nil, nil } - const kbLink = "\nSee https://tailscale.com/s/ip-forwarding" if state == nil { - var err error - state, err = interfaces.GetState() - if err != nil { - return nil, err - } + return nil, fmt.Errorf("Couldn't check system's IP forwarding configuration; no link state") } + const kbLink = "\nSee https://tailscale.com/s/ip-forwarding" wantV4, wantV6 := protocolsRequiredForForwarding(routes, state) if !wantV4 && !wantV6 { return nil, nil @@ -212,9 +208,16 @@ func ipForwardingEnabledLinux(p protocol, iface string) (bool, error) { } return false, err } - on, err := strconv.ParseBool(string(bytes.TrimSpace(bs))) + + val, err := strconv.ParseInt(string(bytes.TrimSpace(bs)), 10, 32) if err != nil { return false, fmt.Errorf("couldn't parse %s: %w", k, err) } + // 0 = disabled, 1 = enabled, 2 = enabled (but uncommon) + // https://github.com/tailscale/tailscale/issues/8375 + if val < 0 || val > 2 { + return false, fmt.Errorf("unexpected value %d for %s", val, k) + } + on := val == 1 || val == 2 return on, nil } diff --git a/vendor/tailscale.com/net/netutil/routes.go b/vendor/tailscale.com/net/netutil/routes.go new file mode 100644 index 0000000000..83f29bf3a1 --- /dev/null +++ b/vendor/tailscale.com/net/netutil/routes.go @@ -0,0 +1,93 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netutil + +import ( + "encoding/binary" + "fmt" + "net/netip" + "sort" + "strings" + + "tailscale.com/net/tsaddr" +) + +var ( + ipv4default = netip.MustParsePrefix("0.0.0.0/0") + ipv6default = netip.MustParsePrefix("::/0") +) + +func validateViaPrefix(ipp netip.Prefix) error { + if !tsaddr.IsViaPrefix(ipp) { + return fmt.Errorf("%v is not a 4-in-6 prefix", ipp) + } + if ipp.Bits() < (128 - 32) { + return fmt.Errorf("%v 4-in-6 prefix must be at least a /%v", ipp, 128-32) + } + a := ipp.Addr().As16() + // The first 64 bits of a are 
the via prefix. + // The next 32 bits are the "site ID". + // The last 32 bits are the IPv4. + // For now, we reserve the top 3 bytes of the site ID, + // and only allow users to use site IDs 0-255. + siteID := binary.BigEndian.Uint32(a[8:12]) + if siteID > 0xFF { + return fmt.Errorf("route %v contains invalid site ID %08x; must be 0xff or less", ipp, siteID) + } + return nil +} + +// CalcAdvertiseRoutes calculates the requested routes to be advertised by a node. +// advertiseRoutes is the user-provided, comma-separated list of routes (IP addresses or CIDR prefixes) to advertise. +// advertiseDefaultRoute indicates whether the node should act as an exit node and advertise default routes. +func CalcAdvertiseRoutes(advertiseRoutes string, advertiseDefaultRoute bool) ([]netip.Prefix, error) { + routeMap := map[netip.Prefix]bool{} + if advertiseRoutes != "" { + var default4, default6 bool + advroutes := strings.Split(advertiseRoutes, ",") + for _, s := range advroutes { + ipp, err := netip.ParsePrefix(s) + if err != nil { + return nil, fmt.Errorf("%q is not a valid IP address or CIDR prefix", s) + } + if ipp != ipp.Masked() { + return nil, fmt.Errorf("%s has non-address bits set; expected %s", ipp, ipp.Masked()) + } + if tsaddr.IsViaPrefix(ipp) { + if err := validateViaPrefix(ipp); err != nil { + return nil, err + } + } + if ipp == ipv4default { + default4 = true + } else if ipp == ipv6default { + default6 = true + } + routeMap[ipp] = true + } + if default4 && !default6 { + return nil, fmt.Errorf("%s advertised without its IPv6 counterpart, please also advertise %s", ipv4default, ipv6default) + } else if default6 && !default4 { + return nil, fmt.Errorf("%s advertised without its IPv4 counterpart, please also advertise %s", ipv6default, ipv4default) + } + } + if advertiseDefaultRoute { + routeMap[netip.MustParsePrefix("0.0.0.0/0")] = true + routeMap[netip.MustParsePrefix("::/0")] = true + } + if len(routeMap) == 0 { + return nil, nil + } + routes := make([]netip.Prefix, 
0, len(routeMap)) + for r := range routeMap { + routes = append(routes, r) + } + sort.Slice(routes, func(i, j int) bool { + if routes[i].Bits() != routes[j].Bits() { + return routes[i].Bits() < routes[j].Bits() + } + return routes[i].Addr().Less(routes[j].Addr()) + }) + return routes, nil +} diff --git a/vendor/tailscale.com/net/ping/ping.go b/vendor/tailscale.com/net/ping/ping.go index 170d87fb94..01f3dcf2c4 100644 --- a/vendor/tailscale.com/net/ping/ping.go +++ b/vendor/tailscale.com/net/ping/ping.go @@ -303,7 +303,7 @@ func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Dur m := icmp.Message{ Type: icmpType, - Code: icmpType.Protocol(), + Code: 0, Body: &icmp.Echo{ ID: int(p.id), Seq: int(seq), diff --git a/vendor/tailscale.com/net/portmapper/portmapper.go b/vendor/tailscale.com/net/portmapper/portmapper.go index 2e500b5553..3fde487cd0 100644 --- a/vendor/tailscale.com/net/portmapper/portmapper.go +++ b/vendor/tailscale.com/net/portmapper/portmapper.go @@ -18,6 +18,7 @@ import ( "time" "go4.org/mem" + "tailscale.com/control/controlknobs" "tailscale.com/net/interfaces" "tailscale.com/net/netaddr" "tailscale.com/net/neterror" @@ -27,6 +28,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/nettype" "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" ) // DebugKnobs contains debug configuration that can be provided when creating a @@ -36,6 +38,11 @@ type DebugKnobs struct { // to its logger. VerboseLogs bool + // LogHTTP tells the Client to print the raw HTTP logs (from UPnP) to + // its logger. This is useful when debugging buggy UPnP + // implementations. + LogHTTP bool + // Disable* disables a specific service from mapping. 
DisableUPnP bool DisablePMP bool @@ -61,6 +68,7 @@ const trustServiceStillAvailableDuration = 10 * time.Minute type Client struct { logf logger.Logf netMon *netmon.Monitor // optional; nil means interfaces will be looked up on-demand + controlKnobs *controlknobs.Knobs ipAndGateway func() (gw, ip netip.Addr, ok bool) onChange func() // or nil debug DebugKnobs @@ -161,15 +169,19 @@ func (m *pmpMapping) Release(ctx context.Context) { // The debug argument allows configuring the behaviour of the portmapper for // debugging; if nil, a sensible set of defaults will be used. // -// The optional onChange argument specifies a func to run in a new -// goroutine whenever the port mapping status has changed. If nil, -// it doesn't make a callback. -func NewClient(logf logger.Logf, netMon *netmon.Monitor, debug *DebugKnobs, onChange func()) *Client { +// The controlKnobs, if non-nil, specifies the control knobs from the control +// plane that might disable portmapping. +// +// The optional onChange argument specifies a func to run in a new goroutine +// whenever the port mapping status has changed. If nil, it doesn't make a +// callback. +func NewClient(logf logger.Logf, netMon *netmon.Monitor, debug *DebugKnobs, controlKnobs *controlknobs.Knobs, onChange func()) *Client { ret := &Client{ logf: logf, netMon: netMon, ipAndGateway: interfaces.LikelyHomeRouterIP, onChange: onChange, + controlKnobs: controlKnobs, } if debug != nil { ret.debug = *debug @@ -1013,3 +1025,31 @@ var ( // we received a UPnP response with a new meta. 
metricUPnPUpdatedMeta = clientmetric.NewCounter("portmap_upnp_updated_meta") ) + +// UPnP error metric that's keyed by code; lazily registered on first read +var ( + metricUPnPErrorsByCodeMu sync.Mutex + metricUPnPErrorsByCode map[int]*clientmetric.Metric +) + +func getUPnPErrorsMetric(code int) *clientmetric.Metric { + metricUPnPErrorsByCodeMu.Lock() + defer metricUPnPErrorsByCodeMu.Unlock() + mm := metricUPnPErrorsByCode[code] + if mm != nil { + return mm + } + + // Metric names cannot contain a hyphen, so we handle negative numbers + // by prefixing the name with a "minus_". + var codeStr string + if code < 0 { + codeStr = fmt.Sprintf("portmap_upnp_errors_with_code_minus_%d", -code) + } else { + codeStr = fmt.Sprintf("portmap_upnp_errors_with_code_%d", code) + } + + mm = clientmetric.NewCounter(codeStr) + mak.Set(&metricUPnPErrorsByCode, code, mm) + return mm +} diff --git a/vendor/tailscale.com/net/portmapper/upnp.go b/vendor/tailscale.com/net/portmapper/upnp.go index 6a525c54ff..31650a516e 100644 --- a/vendor/tailscale.com/net/portmapper/upnp.go +++ b/vendor/tailscale.com/net/portmapper/upnp.go @@ -11,18 +11,22 @@ import ( "bufio" "bytes" "context" + "encoding/xml" "fmt" + "io" "math/rand" "net" "net/http" "net/netip" "net/url" "strings" + "sync/atomic" "time" "github.com/tailscale/goupnp" "github.com/tailscale/goupnp/dcps/internetgateway2" - "tailscale.com/control/controlknobs" + "github.com/tailscale/goupnp/soap" + "tailscale.com/envknob" "tailscale.com/net/netns" "tailscale.com/types/logger" ) @@ -106,11 +110,13 @@ type upnpClient interface { // It is not used for anything other than labelling. const tsPortMappingDesc = "tailscale-portmap" -// addAnyPortMapping abstracts over different UPnP client connections, calling the available -// AddAnyPortMapping call if available for WAN IP connection v2, otherwise defaulting to the old -// behavior of calling AddPortMapping with port = 0 to specify a wildcard port. 
-// It returns the new external port (which may not be identical to the external port specified), -// or an error. +// addAnyPortMapping abstracts over different UPnP client connections, calling +// the available AddAnyPortMapping call if available for WAN IP connection v2, +// otherwise picking either the previous port (if one is present) or a random +// port and trying to obtain a mapping using AddPortMapping. +// +// It returns the new external port (which may not be identical to the external +// port specified), or an error. // // TODO(bradfitz): also returned the actual lease duration obtained. and check it regularly. func addAnyPortMapping( @@ -121,6 +127,31 @@ func addAnyPortMapping( internalClient string, leaseDuration time.Duration, ) (newPort uint16, err error) { + // Some devices don't let clients add a port mapping for privileged + // ports (ports below 1024). Additionally, per section 2.3.18 of the + // UPnP spec, regarding the ExternalPort field: + // + // If this value is specified as a wildcard (i.e. 0), connection + // request on all external ports (that are not otherwise mapped) + // will be forwarded to InternalClient. In the wildcard case, the + // value(s) of InternalPort on InternalClient are ignored by the IGD + // for those connections that are forwarded to InternalClient. + // Obviously only one such entry can exist in the NAT at any time + // and conflicts are handled with a “first write wins” behavior. + // + // We obviously do not want to open all ports on the user's device to + // the internet, so we want to do this prior to calling either + // AddAnyPortMapping or AddPortMapping. + // + // Pick an external port that's greater than 1024 by getting a random + // number in [0, 65535 - 1024] and then adding 1024 to it, shifting the + // range to [1024, 65535]. 
+ if externalPort < 1024 { + externalPort = uint16(rand.Intn(65535-1024) + 1024) + } + + // First off, try using AddAnyPortMapping; if there's a conflict, the + // router will pick another port and return it. if upnp, ok := upnp.(*internetgateway2.WANIPConnection2); ok { return upnp.AddAnyPortMapping( ctx, @@ -135,15 +166,8 @@ func addAnyPortMapping( ) } - // Some devices don't let clients add a port mapping for privileged - // ports (ports below 1024). - // - // Pick an external port that's greater than 1024 by getting a random - // number in [0, 65535 - 1024] and then adding 1024 to it, shifting the - // range to [1024, 65535]. - if externalPort < 1024 { - externalPort = uint16(rand.Intn(65535-1024) + 1024) - } + // Fall back to using AddPortMapping, which requests a mapping to/from + // a specific external port. err = upnp.AddPortMapping( ctx, "", @@ -170,7 +194,7 @@ func addAnyPortMapping( // The provided ctx is not retained in the returned upnpClient, but // its associated HTTP client is (if set via goupnp.WithHTTPClient). func getUPnPClient(ctx context.Context, logf logger.Logf, debug DebugKnobs, gw netip.Addr, meta uPnPDiscoResponse) (client upnpClient, err error) { - if controlknobs.DisableUPnP() || debug.DisableUPnP { + if debug.DisableUPnP { return nil, nil } @@ -241,10 +265,17 @@ func (c *Client) upnpHTTPClientLocked() *http.Client { IdleConnTimeout: 2 * time.Second, // LAN is cheap }, } + if c.debug.LogHTTP { + c.uPnPHTTPClient = requestLogger(c.logf, c.uPnPHTTPClient) + } } return c.uPnPHTTPClient } +var ( + disableUPnpEnv = envknob.RegisterBool("TS_DISABLE_UPNP") +) + // getUPnPPortMapping attempts to create a port-mapping over the UPnP protocol. On success, // it will return the externally exposed IP and port. Otherwise, it will return a zeroed IP and // port and an error. 
@@ -254,7 +285,7 @@ func (c *Client) getUPnPPortMapping( internal netip.AddrPort, prevPort uint16, ) (external netip.AddrPort, ok bool) { - if controlknobs.DisableUPnP() || c.debug.DisableUPnP { + if disableUPnpEnv() || c.debug.DisableUPnP || (c.controlKnobs != nil && c.controlKnobs.DisableUPnP.Load()) { return netip.AddrPort{}, false } @@ -287,6 +318,7 @@ func (c *Client) getUPnPPortMapping( return netip.AddrPort{}, false } + // Start by trying to make a temporary lease with a duration. var newPort uint16 newPort, err = addAnyPortMapping( ctx, @@ -294,14 +326,42 @@ func (c *Client) getUPnPPortMapping( prevPort, internal.Port(), internal.Addr().String(), - time.Second*pmpMapLifetimeSec, + pmpMapLifetimeSec*time.Second, ) if c.debug.VerboseLogs { c.logf("addAnyPortMapping: %v, err=%q", newPort, err) } + + // If this is an error and the code is + // "OnlyPermanentLeasesSupported", then we retry with no lease + // duration; see the following issue for details: + // https://github.com/tailscale/tailscale/issues/9343 + if err != nil { + code, ok := getUPnPErrorCode(err) + if ok { + getUPnPErrorsMetric(code).Add(1) + } + + // From the UPnP spec: http://upnp.org/specs/gw/UPnP-gw-WANIPConnection-v2-Service.pdf + // 725: OnlyPermanentLeasesSupported + if ok && code == 725 { + newPort, err = addAnyPortMapping( + ctx, + client, + prevPort, + internal.Port(), + internal.Addr().String(), + 0, // permanent + ) + if c.debug.VerboseLogs { + c.logf("addAnyPortMapping: 725 retry %v, err=%q", newPort, err) + } + } + } if err != nil { return netip.AddrPort{}, false } + // TODO cache this ip somewhere? extIP, err := client.GetExternalIPAddress(ctx) if c.debug.VerboseLogs { @@ -317,6 +377,10 @@ func (c *Client) getUPnPPortMapping( } upnp.external = netip.AddrPortFrom(externalIP, newPort) + + // NOTE: this time might not technically be accurate if we created a + // permanent lease above, but we should still re-check the presence of + // the lease on a regular basis so we use it anyway. 
d := time.Duration(pmpMapLifetimeSec) * time.Second upnp.goodUntil = now.Add(d) upnp.renewAfter = now.Add(d / 2) @@ -328,6 +392,29 @@ func (c *Client) getUPnPPortMapping( return upnp.external, true } +// getUPnPErrorCode returns the UPnP error code from the given response, if the +// error is a SOAP error in the proper format, and a boolean indicating whether +// the provided error was actually a UPnP error. +func getUPnPErrorCode(err error) (int, bool) { + soapErr, ok := err.(*soap.SOAPFaultError) + if !ok { + return 0, false + } + + var upnpErr struct { + XMLName xml.Name + Code int `xml:"errorCode"` + Description string `xml:"errorDescription"` + } + if err := xml.Unmarshal([]byte(soapErr.Detail.Raw), &upnpErr); err != nil { + return 0, false + } + if upnpErr.XMLName.Local != "UPnPError" { + return 0, false + } + return upnpErr.Code, true +} + type uPnPDiscoResponse struct { Location string // Server describes what version the UPnP is, such as MiniUPnPd/2.x.x @@ -349,3 +436,60 @@ func parseUPnPDiscoResponse(body []byte) (uPnPDiscoResponse, error) { r.USN = res.Header.Get("Usn") return r, nil } + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (r roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return r(req) +} + +func requestLogger(logf logger.Logf, client *http.Client) *http.Client { + // Clone the HTTP client, and override the Transport to log to the + // provided logger. + ret := *client + oldTransport := ret.Transport + + var requestCounter atomic.Uint64 + loggingTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + ctr := requestCounter.Add(1) + + // Read the body and re-set it. 
+ var ( + body []byte + err error + ) + if req.Body != nil { + body, err = io.ReadAll(req.Body) + req.Body.Close() + if err != nil { + return nil, err + } + req.Body = io.NopCloser(bytes.NewReader(body)) + } + + logf("request[%d]: %s %q body=%q", ctr, req.Method, req.URL, body) + + resp, err := oldTransport.RoundTrip(req) + if err != nil { + logf("response[%d]: err=%v", ctr, err) + return nil, err + } + + // Read the response body + if resp.Body != nil { + body, err = io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + logf("response[%d]: %d bodyErr=%v", ctr, resp.StatusCode, err) + return nil, err + } + resp.Body = io.NopCloser(bytes.NewReader(body)) + } + + logf("response[%d]: %d body=%q", ctr, resp.StatusCode, body) + return resp, nil + }) + ret.Transport = loggingTransport + + return &ret +} diff --git a/vendor/tailscale.com/net/sockstats/sockstats_tsgo.go b/vendor/tailscale.com/net/sockstats/sockstats_tsgo.go index 37edddddf0..d94f279ad5 100644 --- a/vendor/tailscale.com/net/sockstats/sockstats_tsgo.go +++ b/vendor/tailscale.com/net/sockstats/sockstats_tsgo.go @@ -266,25 +266,29 @@ func setNetMon(netMon *netmon.Monitor) { sockStats.usedInterfaces[ifIndex] = 1 } - netMon.RegisterChangeCallback(func(changed bool, state *interfaces.State) { - if changed { - if ifName := state.DefaultRouteInterface; ifName != "" { - ifIndex := state.Interface[ifName].Index - sockStats.mu.Lock() - defer sockStats.mu.Unlock() - // Ignore changes to unknown interfaces -- it would require - // updating the tx/rxBytesByInterface maps and thus - // additional locking for every read/write. Most of the time - // the set of interfaces is static.
- if _, ok := sockStats.knownInterfaces[ifIndex]; ok { - sockStats.currentInterface.Store(uint32(ifIndex)) - sockStats.usedInterfaces[ifIndex] = 1 - sockStats.currentInterfaceCellular.Store(isLikelyCellularInterface(ifName)) - } else { - sockStats.currentInterface.Store(0) - sockStats.currentInterfaceCellular.Store(false) - } - } + netMon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { + if !delta.Major { + return + } + state := delta.New + ifName := state.DefaultRouteInterface + if ifName == "" { + return + } + ifIndex := state.Interface[ifName].Index + sockStats.mu.Lock() + defer sockStats.mu.Unlock() + // Ignore changes to unknown interfaces -- it would require + // updating the tx/rxBytesByInterface maps and thus + // additional locking for every read/write. Most of the time + // the set of interfaces is static. + if _, ok := sockStats.knownInterfaces[ifIndex]; ok { + sockStats.currentInterface.Store(uint32(ifIndex)) + sockStats.usedInterfaces[ifIndex] = 1 + sockStats.currentInterfaceCellular.Store(isLikelyCellularInterface(ifName)) + } else { + sockStats.currentInterface.Store(0) + sockStats.currentInterfaceCellular.Store(false) } }) } diff --git a/vendor/tailscale.com/net/tcpinfo/tcpinfo.go b/vendor/tailscale.com/net/tcpinfo/tcpinfo.go new file mode 100644 index 0000000000..a757add9f8 --- /dev/null +++ b/vendor/tailscale.com/net/tcpinfo/tcpinfo.go @@ -0,0 +1,51 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tcpinfo provides platform-agnostic accessors to information about a +// TCP connection (e.g. RTT, MSS, etc.). +package tcpinfo + +import ( + "errors" + "net" + "time" +) + +var ( + ErrNotTCP = errors.New("tcpinfo: not a TCP conn") + ErrUnimplemented = errors.New("tcpinfo: unimplemented") +) + +// RTT returns the RTT for the given net.Conn. +// +// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then +// ErrNotTCP will be returned. 
If retrieving the RTT is not supported on the +// current platform, ErrUnimplemented will be returned. +func RTT(conn net.Conn) (time.Duration, error) { + tcpConn, err := unwrap(conn) + if err != nil { + return 0, err + } + + return rttImpl(tcpConn) +} + +// netConner is implemented by crypto/tls.Conn to unwrap into an underlying +// net.Conn. +type netConner interface { + NetConn() net.Conn +} + +// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn +func unwrap(nc net.Conn) (*net.TCPConn, error) { + for { + switch v := nc.(type) { + case *net.TCPConn: + return v, nil + case netConner: + nc = v.NetConn() + default: + return nil, ErrNotTCP + } + } +} diff --git a/vendor/tailscale.com/net/tcpinfo/tcpinfo_darwin.go b/vendor/tailscale.com/net/tcpinfo/tcpinfo_darwin.go new file mode 100644 index 0000000000..53fa22fbf5 --- /dev/null +++ b/vendor/tailscale.com/net/tcpinfo/tcpinfo_darwin.go @@ -0,0 +1,33 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPConnectionInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil +} diff --git a/vendor/tailscale.com/net/tcpinfo/tcpinfo_linux.go b/vendor/tailscale.com/net/tcpinfo/tcpinfo_linux.go new file mode 100644 index 0000000000..885d462c95 --- /dev/null +++ b/vendor/tailscale.com/net/tcpinfo/tcpinfo_linux.go @@ -0,0 +1,33 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + 
"golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil +} diff --git a/vendor/tailscale.com/net/tcpinfo/tcpinfo_other.go b/vendor/tailscale.com/net/tcpinfo/tcpinfo_other.go new file mode 100644 index 0000000000..be45523aeb --- /dev/null +++ b/vendor/tailscale.com/net/tcpinfo/tcpinfo_other.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin + +package tcpinfo + +import ( + "net" + "time" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + return 0, ErrUnimplemented +} diff --git a/vendor/tailscale.com/net/tlsdial/tlsdial.go b/vendor/tailscale.com/net/tlsdial/tlsdial.go index d571d38a6b..2c9be4e1cb 100644 --- a/vendor/tailscale.com/net/tlsdial/tlsdial.go +++ b/vendor/tailscale.com/net/tlsdial/tlsdial.go @@ -1,22 +1,24 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package tlsdial originally existed to set up a tls.Config for x509 -// validation, using a memory-optimized path for iOS, but then we -// moved that to the tailscale/go tree instead, so now this package -// does very little. But for now we keep it as a unified point where -// we might want to add shared policy on outgoing TLS connections from -// the 3 places in the client that connect to Tailscale (logs, -// control, DERP). +// Package tlsdial generates tls.Config values and does x509 validation of +// certs. 
It bakes in the LetsEncrypt roots so even if the user's machine +// doesn't have TLS roots, we can at least connect to Tailscale's LetsEncrypt +// services. It's the unified point where we can add shared policy on outgoing +// TLS connections from the three places in the client that connect to Tailscale +// (logs, control, DERP). package tlsdial import ( "bytes" + "context" "crypto/tls" "crypto/x509" "errors" "fmt" "log" + "net" + "net/http" "os" "sync" "sync/atomic" @@ -192,6 +194,22 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { } } +// NewTransport returns a new HTTP transport that verifies TLS certs using this +// package, including its baked-in LetsEncrypt fallback roots. +func NewTransport() *http.Transport { + return &http.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + var d tls.Dialer + d.Config = Config(host, nil) + return d.DialContext(ctx, network, addr) + }, + } +} + /* letsEncryptX1 is the LetsEncrypt X1 root: diff --git a/vendor/tailscale.com/net/tsaddr/tsaddr.go b/vendor/tailscale.com/net/tsaddr/tsaddr.go index 34259b6907..088ff35e12 100644 --- a/vendor/tailscale.com/net/tsaddr/tsaddr.go +++ b/vendor/tailscale.com/net/tsaddr/tsaddr.go @@ -8,10 +8,12 @@ import ( "encoding/binary" "errors" "net/netip" + "slices" "sync" - "golang.org/x/exp/slices" + "go4.org/netipx" "tailscale.com/net/netaddr" + "tailscale.com/types/views" ) // ChromeOSVMRange returns the subset of the CGNAT IPv4 range used by @@ -159,6 +161,11 @@ type oncePrefix struct { v netip.Prefix } +// FalseContainsIPFunc is shorthand for NewContainsIPFunc(views.Slice[netip.Prefix]{}). +func FalseContainsIPFunc() func(ip netip.Addr) bool { + return func(ip netip.Addr) bool { return false } +} + // NewContainsIPFunc returns a func that reports whether ip is in addrs. 
// // It's optimized for the cases of addrs being empty and addrs @@ -166,20 +173,17 @@ type oncePrefix struct { // one IPv6 address). // // Otherwise the implementation is somewhat slow. -func NewContainsIPFunc(addrs []netip.Prefix) func(ip netip.Addr) bool { +func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool { // Specialize the three common cases: no address, just IPv4 // (or just IPv6), and both IPv4 and IPv6. - if len(addrs) == 0 { + if addrs.Len() == 0 { return func(netip.Addr) bool { return false } } // If any addr is more than a single IP, then just do the slow // linear thing until // https://github.com/inetaf/netaddr/issues/139 is done. - for _, a := range addrs { - if a.IsSingleIP() { - continue - } - acopy := append([]netip.Prefix(nil), addrs...) + if views.SliceContainsFunc(addrs, func(p netip.Prefix) bool { return !p.IsSingleIP() }) { + acopy := addrs.AsSlice() return func(ip netip.Addr) bool { for _, a := range acopy { if a.Contains(ip) { @@ -190,18 +194,18 @@ func NewContainsIPFunc(addrs []netip.Prefix) func(ip netip.Addr) bool { } } // Fast paths for 1 and 2 IPs: - if len(addrs) == 1 { - a := addrs[0] + if addrs.Len() == 1 { + a := addrs.At(0) return func(ip netip.Addr) bool { return ip == a.Addr() } } - if len(addrs) == 2 { - a, b := addrs[0], addrs[1] + if addrs.Len() == 2 { + a, b := addrs.At(0), addrs.At(1) return func(ip netip.Addr) bool { return ip == a.Addr() || ip == b.Addr() } } // General case: m := map[netip.Addr]bool{} - for _, a := range addrs { - m[a.Addr()] = true + for i := range addrs.LenIter() { + m[addrs.At(i).Addr()] = true } return func(ip netip.Addr) bool { return m[ip] } } @@ -224,9 +228,10 @@ func PrefixIs6(p netip.Prefix) bool { return p.Addr().Is6() } // ContainsExitRoutes reports whether rr contains both the IPv4 and // IPv6 /0 route. 
-func ContainsExitRoutes(rr []netip.Prefix) bool { +func ContainsExitRoutes(rr views.Slice[netip.Prefix]) bool { var v4, v6 bool - for _, r := range rr { + for i := range rr.LenIter() { + r := rr.At(i) if r == allIPv4 { v4 = true } else if r == allIPv6 { @@ -236,6 +241,17 @@ func ContainsExitRoutes(rr []netip.Prefix) bool { return v4 && v6 } +// ContainsNonExitSubnetRoutes reports whether v contains Subnet +// Routes other than ExitNode Routes. +func ContainsNonExitSubnetRoutes(rr views.Slice[netip.Prefix]) bool { + for i := range rr.LenIter() { + if rr.At(i).Bits() != 0 { + return true + } + } + return false +} + var ( allIPv4 = netip.MustParsePrefix("0.0.0.0/0") allIPv6 = netip.MustParsePrefix("::/0") @@ -252,20 +268,15 @@ func ExitRoutes() []netip.Prefix { return []netip.Prefix{allIPv4, allIPv6} } // SortPrefixes sorts the prefixes in place. func SortPrefixes(p []netip.Prefix) { - slices.SortFunc(p, func(ri, rj netip.Prefix) bool { - if ri.Addr() == rj.Addr() { - return ri.Bits() < rj.Bits() - } - return ri.Addr().Less(rj.Addr()) - }) + slices.SortFunc(p, netipx.ComparePrefix) } // FilterPrefixes returns a new slice, not aliasing in, containing elements of // in that match f. 
-func FilterPrefixesCopy(in []netip.Prefix, f func(netip.Prefix) bool) []netip.Prefix { +func FilterPrefixesCopy(in views.Slice[netip.Prefix], f func(netip.Prefix) bool) []netip.Prefix { var out []netip.Prefix - for _, v := range in { - if f(v) { + for i := range in.LenIter() { + if v := in.At(i); f(v) { out = append(out, v) } } diff --git a/vendor/tailscale.com/net/tsdial/dnsmap.go b/vendor/tailscale.com/net/tsdial/dnsmap.go index 52b8e6d817..a549b52039 100644 --- a/vendor/tailscale.com/net/tsdial/dnsmap.go +++ b/vendor/tailscale.com/net/tsdial/dnsmap.go @@ -35,30 +35,32 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { ret := make(dnsMap) suffix := nm.MagicDNSSuffix() have4 := false - if nm.Name != "" && len(nm.Addresses) > 0 { - ip := nm.Addresses[0].Addr() + addrs := nm.GetAddresses() + if nm.Name != "" && addrs.Len() > 0 { + ip := addrs.At(0).Addr() ret[canonMapKey(nm.Name)] = ip if dnsname.HasSuffix(nm.Name, suffix) { ret[canonMapKey(dnsname.TrimSuffix(nm.Name, suffix))] = ip } - for _, a := range nm.Addresses { - if a.Addr().Is4() { + for i := range addrs.LenIter() { + if addrs.At(i).Addr().Is4() { have4 = true } } } for _, p := range nm.Peers { - if p.Name == "" { + if p.Name() == "" { continue } - for _, a := range p.Addresses { + for i := range p.Addresses().LenIter() { + a := p.Addresses().At(i) ip := a.Addr() if ip.Is4() && !have4 { continue } - ret[canonMapKey(p.Name)] = ip - if dnsname.HasSuffix(p.Name, suffix) { - ret[canonMapKey(dnsname.TrimSuffix(p.Name, suffix))] = ip + ret[canonMapKey(p.Name())] = ip + if dnsname.HasSuffix(p.Name(), suffix) { + ret[canonMapKey(dnsname.TrimSuffix(p.Name(), suffix))] = ip } break } diff --git a/vendor/tailscale.com/net/tsdial/tsdial.go b/vendor/tailscale.com/net/tsdial/tsdial.go index 54e91ca13c..e901369a3f 100644 --- a/vendor/tailscale.com/net/tsdial/tsdial.go +++ b/vendor/tailscale.com/net/tsdial/tsdial.go @@ -18,12 +18,12 @@ import ( "time" "tailscale.com/net/dnscache" - "tailscale.com/net/interfaces" 
"tailscale.com/net/netknob" "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/util/clientmetric" "tailscale.com/util/mak" ) @@ -139,18 +139,52 @@ func (d *Dialer) SetNetMon(netMon *netmon.Monitor) { d.netMonUnregister = d.netMon.RegisterChangeCallback(d.linkChanged) } -func (d *Dialer) linkChanged(major bool, state *interfaces.State) { - if !major { - return - } +var ( + metricLinkChangeConnClosed = clientmetric.NewCounter("tsdial_linkchange_closes") +) + +func (d *Dialer) linkChanged(delta *netmon.ChangeDelta) { d.mu.Lock() defer d.mu.Unlock() + var anyClosed bool for id, c := range d.activeSysConns { - go c.Close() - delete(d.activeSysConns, id) + if changeAffectsConn(delta, c) { + anyClosed = true + d.logf("tsdial: closing system connection %v->%v due to link change", c.LocalAddr(), c.RemoteAddr()) + go c.Close() + delete(d.activeSysConns, id) + } + } + if anyClosed { + metricLinkChangeConnClosed.Add(1) } } +// changeAffectsConn reports whether the network change delta affects +// the provided connection. +func changeAffectsConn(delta *netmon.ChangeDelta, conn net.Conn) bool { + la, _ := conn.LocalAddr().(*net.TCPAddr) + ra, _ := conn.RemoteAddr().(*net.TCPAddr) + if la == nil || ra == nil { + return false // not TCP + } + lip, rip := la.AddrPort().Addr(), ra.AddrPort().Addr() + + if delta.Old == nil { + return false + } + if delta.Old.DefaultRouteInterface != delta.New.DefaultRouteInterface || + delta.Old.HTTPProxy != delta.New.HTTPProxy { + return true + } + if !delta.New.HasIP(lip) && delta.Old.HasIP(lip) { + // Our interface with this source IP went away. + return true + } + _ = rip // TODO(bradfitz): use the remote IP? + return false +} + func (d *Dialer) closeSysConn(id int) { d.mu.Lock() defer d.mu.Unlock() @@ -203,6 +237,8 @@ func (d *Dialer) SetNetMap(nm *netmap.NetworkMap) { d.dns = m } +// userDialResolve resolves addr as if a user initiating the dial. (e.g. 
from a +// SOCKS or HTTP outbound proxy) func (d *Dialer) userDialResolve(ctx context.Context, network, addr string) (netip.AddrPort, error) { d.mu.Lock() dns := d.dns @@ -262,6 +298,12 @@ func ipNetOfNetwork(n string) string { return "ip" } +func (d *Dialer) logf(format string, args ...any) { + if d.Logf != nil { + d.Logf(format, args...) + } +} + // SystemDial connects to the provided network address without going over // Tailscale. It prefers going over the default interface and closes existing // connections if the default interface changes. It is used to connect to @@ -275,11 +317,7 @@ func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn } d.netnsDialerOnce.Do(func() { - logf := d.Logf - if logf == nil { - logf = logger.Discard - } - d.netnsDialer = netns.NewDialer(logf, d.netMon) + d.netnsDialer = netns.NewDialer(d.logf, d.netMon) }) c, err := d.netnsDialer.DialContext(ctx, network, addr) if err != nil { @@ -298,8 +336,8 @@ func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn }, nil } -// UserDial connects to the provided network address as if a user were initiating the dial. -// (e.g. from a SOCKS or HTTP outbound proxy) +// UserDial connects to the provided network address as if a user were +// initiating the dial. (e.g. 
from a SOCKS or HTTP outbound proxy) func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, error) { ipp, err := d.userDialResolve(ctx, network, addr) if err != nil { diff --git a/vendor/tailscale.com/net/tstun/mtu.go b/vendor/tailscale.com/net/tstun/mtu.go index 2307d47f96..fe8eab3e89 100644 --- a/vendor/tailscale.com/net/tstun/mtu.go +++ b/vendor/tailscale.com/net/tstun/mtu.go @@ -5,8 +5,8 @@ package tstun import "tailscale.com/envknob" const ( - maxMTU uint32 = 65536 - defaultMTU uint32 = 1280 + maxMTU = 65536 + defaultMTU = 1280 ) // DefaultMTU returns either the constant default MTU of 1280, or the value set @@ -21,13 +21,8 @@ func DefaultMTU() uint32 { // 1280 is the smallest MTU allowed for IPv6, which is a sensible // "probably works everywhere" setting until we develop proper PMTU // discovery. - tunMTU := defaultMTU if mtu, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok { - mtu := uint32(mtu) - if mtu > maxMTU { - mtu = maxMTU - } - tunMTU = mtu + return min(uint32(mtu), maxMTU) } - return tunMTU + return defaultMTU } diff --git a/vendor/tailscale.com/net/tstun/tstun_plan9.go b/vendor/tailscale.com/net/tstun/tstun_plan9.go new file mode 100644 index 0000000000..4472a7a5d7 --- /dev/null +++ b/vendor/tailscale.com/net/tstun/tstun_plan9.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" +) + +func New(logf logger.Logf, tunName string) (tun.Device, string, error) { + panic("not implemented") +} + +func Diagnose(logf logger.Logf, tunName string, err error) { + panic("not implemented") +} diff --git a/vendor/tailscale.com/net/tstun/tun.go b/vendor/tailscale.com/net/tstun/tun.go index b31ffa7ca6..de0db6d1ea 100644 --- a/vendor/tailscale.com/net/tstun/tun.go +++ b/vendor/tailscale.com/net/tstun/tun.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // 
SPDX-License-Identifier: BSD-3-Clause -//go:build !wasm +//go:build !wasm && !plan9 && !tamago // Package tun creates a tuntap device, working around OS-specific // quirks if necessary. diff --git a/vendor/tailscale.com/net/tstun/wrap.go b/vendor/tailscale.com/net/tstun/wrap.go index 74bf54134c..ab0a4a3db9 100644 --- a/vendor/tailscale.com/net/tstun/wrap.go +++ b/vendor/tailscale.com/net/tstun/wrap.go @@ -12,6 +12,7 @@ import ( "net/netip" "os" "reflect" + "slices" "strings" "sync" "sync/atomic" @@ -20,7 +21,6 @@ import ( "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "go4.org/mem" - "golang.org/x/exp/slices" "gvisor.dev/gvisor/pkg/tcpip/stack" "tailscale.com/disco" "tailscale.com/net/connstats" @@ -35,6 +35,7 @@ import ( "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" + "tailscale.com/util/set" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" @@ -97,7 +98,7 @@ type Wrapper struct { // timeNow, if non-nil, will be used to obtain the current time. timeNow func() time.Time - // natV4Config stores the current NAT configuration. + // natV4Config stores the current IPv4 NAT configuration. natV4Config atomic.Pointer[natV4Config] // vectorBuffer stores the oldest unconsumed packet vector from tdev. 
It is @@ -544,6 +545,44 @@ type natV4Config struct { dstAddrToPeerKeyMapper *table.RoutingTable } +func (c *natV4Config) String() string { + if c == nil { + return "" + } + var b strings.Builder + b.WriteString("natV4Config{") + fmt.Fprintf(&b, "nativeAddr: %v, ", c.nativeAddr) + fmt.Fprint(&b, "listenAddrs: [") + + i := 0 + c.listenAddrs.Range(func(k netip.Addr, _ struct{}) bool { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(k.String()) + i++ + return true + }) + count := map[netip.Addr]int{} + c.dstMasqAddrs.Range(func(_ key.NodePublic, v netip.Addr) bool { + count[v]++ + return true + }) + + i = 0 + b.WriteString("], dstMasqAddrs: [") + for k, v := range count { + if i > 0 { + b.WriteString(", ") + } + fmt.Fprintf(&b, "%v: %v peers", k, v) + i++ + } + b.WriteString("]}") + + return b.String() +} + // mapDstIP returns the destination IP to use for a packet to dst. // If dst is not one of the listen addresses, it is returned as-is, // otherwise the native address is returned. @@ -576,9 +615,9 @@ func (c *natV4Config) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { return oldSrc } -// natConfigFromWireGuardConfig generates a natV4Config from nm. +// natV4ConfigFromWGConfig generates a natV4Config from nm. // If v4 NAT is not required, it returns nil. -func natConfigFromWGConfig(wcfg *wgcfg.Config) *natV4Config { +func natV4ConfigFromWGConfig(wcfg *wgcfg.Config) *natV4Config { if wcfg == nil { return nil } @@ -589,7 +628,7 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config) *natV4Config { var ( rt table.RoutingTableBuilder dstMasqAddrs map[key.NodePublic]netip.Addr - listenAddrs map[netip.Addr]struct{} + listenAddrs set.Set[netip.Addr] ) // When using an exit node that requires masquerading, we need to @@ -631,10 +670,10 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config) *natV4Config { // SetNetMap is called when a new NetworkMap is received. // It currently (2023-03-01) only updates the IPv4 NAT configuration. 
func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { - cfg := natConfigFromWGConfig(wcfg) + cfg := natV4ConfigFromWGConfig(wcfg) old := t.natV4Config.Swap(cfg) if !reflect.DeepEqual(old, cfg) { - t.logf("nat config: %+v", cfg) + t.logf("nat config: %v", cfg) } } diff --git a/vendor/tailscale.com/net/wsconn/wsconn.go b/vendor/tailscale.com/net/wsconn/wsconn.go index 697b66ddde..2d708ac531 100644 --- a/vendor/tailscale.com/net/wsconn/wsconn.go +++ b/vendor/tailscale.com/net/wsconn/wsconn.go @@ -48,10 +48,18 @@ import ( // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. -func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageType) net.Conn { +// +// The given remoteAddr will be the value of the returned conn's +// RemoteAddr().String(). For best compatibility with consumers of +// conns, the string should be an ip:port if available, but in the +// absence of that it can be any string that describes the remote +// endpoint, or the empty string to makes RemoteAddr() return a place +// holder value. 
+func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageType, remoteAddr string) net.Conn { nc := &netConn{ - c: c, - msgType: msgType, + c: c, + msgType: msgType, + remoteAddr: remoteAddr, } var writeCancel context.CancelFunc @@ -82,8 +90,9 @@ func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageTy } type netConn struct { - c *websocket.Conn - msgType websocket.MessageType + c *websocket.Conn + msgType websocket.MessageType + remoteAddr string writeTimer *time.Timer writeContext context.Context @@ -167,6 +176,7 @@ func (c *netConn) Read(p []byte) (int, error) { } type websocketAddr struct { + addr string } func (a websocketAddr) Network() string { @@ -174,15 +184,18 @@ func (a websocketAddr) Network() string { } func (a websocketAddr) String() string { + if a.addr != "" { + return a.addr + } return "websocket/unknown-addr" } func (c *netConn) RemoteAddr() net.Addr { - return websocketAddr{} + return websocketAddr{c.remoteAddr} } func (c *netConn) LocalAddr() net.Addr { - return websocketAddr{} + return websocketAddr{""} } func (c *netConn) SetDeadline(t time.Time) error { diff --git a/vendor/tailscale.com/paths/paths.go b/vendor/tailscale.com/paths/paths.go index 343692f203..28c3be02a9 100644 --- a/vendor/tailscale.com/paths/paths.go +++ b/vendor/tailscale.com/paths/paths.go @@ -27,6 +27,9 @@ func DefaultTailscaledSocket() string { if runtime.GOOS == "darwin" { return "/var/run/tailscaled.socket" } + if runtime.GOOS == "plan9" { + return "/srv/tailscaled.sock" + } switch distro.Get() { case distro.Synology: if distro.DSMVersion() == 6 { @@ -45,7 +48,14 @@ func DefaultTailscaledSocket() string { return "tailscaled.sock" } -var stateFileFunc func() string +// Overridden in init by OS-specific files. +var ( + stateFileFunc func() string + + // ensureStateDirPerms applies a restrictive ACL/chmod + // to the provided directory. 
+ ensureStateDirPerms = func(string) error { return nil } +) // DefaultTailscaledStateFile returns the default path to the // tailscaled state file, or the empty string if there's no reasonable @@ -67,6 +77,16 @@ func MkStateDir(dirPath string) error { if err := os.MkdirAll(dirPath, 0700); err != nil { return err } - return ensureStateDirPerms(dirPath) } + +// LegacyStateFilePath returns the legacy path to the state file when +// it was stored under the current user's %LocalAppData%. +// +// It is only called on Windows. +func LegacyStateFilePath() string { + if runtime.GOOS == "windows" { + return filepath.Join(os.Getenv("LocalAppData"), "Tailscale", "server-state.conf") + } + return "" +} diff --git a/vendor/tailscale.com/paths/paths_unix.go b/vendor/tailscale.com/paths/paths_unix.go index a8e22e7b63..fb081cdf43 100644 --- a/vendor/tailscale.com/paths/paths_unix.go +++ b/vendor/tailscale.com/paths/paths_unix.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !windows && !js && !wasip1 +//go:build !windows && !wasm && !plan9 && !tamago package paths @@ -17,6 +17,7 @@ import ( func init() { stateFileFunc = stateFileUnix + ensureStateDirPerms = ensureStateDirPermsUnix } func statePath() string { @@ -65,7 +66,7 @@ func xdgDataHome() string { return filepath.Join(os.Getenv("HOME"), ".local/share") } -func ensureStateDirPerms(dir string) error { +func ensureStateDirPermsUnix(dir string) error { if filepath.Base(dir) != "tailscale" { return nil } @@ -83,8 +84,3 @@ func ensureStateDirPerms(dir string) error { } return os.Chmod(dir, perm) } - -// LegacyStateFilePath is not applicable to UNIX; it is just stubbed out. 
-func LegacyStateFilePath() string { - return "" -} diff --git a/vendor/tailscale.com/paths/paths_wasm.go b/vendor/tailscale.com/paths/paths_wasm.go deleted file mode 100644 index 81e9f15409..0000000000 --- a/vendor/tailscale.com/paths/paths_wasm.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package paths - -func ensureStateDirPerms(dirPath string) error { - return nil -} - -func LegacyStateFilePath() string { return "" } diff --git a/vendor/tailscale.com/paths/paths_windows.go b/vendor/tailscale.com/paths/paths_windows.go index aa31e6f6ed..4705400655 100644 --- a/vendor/tailscale.com/paths/paths_windows.go +++ b/vendor/tailscale.com/paths/paths_windows.go @@ -12,7 +12,11 @@ import ( "tailscale.com/util/winutil" ) -// ensureStateDirPerms applies a restrictive ACL to the directory specified by dirPath. +func init() { + ensureStateDirPerms = ensureStateDirPermsWindows +} + +// ensureStateDirPermsWindows applies a restrictive ACL to the directory specified by dirPath. // It sets the following security attributes on the directory: // Owner: The user for the current process; // Primary Group: The primary group for the current process; @@ -26,7 +30,7 @@ import ( // // However, any directories and/or files created within this // directory *do* inherit the ACL that we are setting. -func ensureStateDirPerms(dirPath string) error { +func ensureStateDirPermsWindows(dirPath string) error { fi, err := os.Stat(dirPath) if err != nil { return err @@ -94,9 +98,3 @@ func ensureStateDirPerms(dirPath string) error { return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, flags, sids.User, sids.PrimaryGroup, dacl, nil) } - -// LegacyStateFilePath returns the legacy path to the state file when it was stored under the -// current user's %LocalAppData%. 
-func LegacyStateFilePath() string { - return filepath.Join(os.Getenv("LocalAppData"), "Tailscale", "server-state.conf") -} diff --git a/vendor/tailscale.com/portlist/poller.go b/vendor/tailscale.com/portlist/poller.go index d1f5b2ab0c..423bad3be3 100644 --- a/vendor/tailscale.com/portlist/poller.go +++ b/vendor/tailscale.com/portlist/poller.go @@ -10,10 +10,10 @@ import ( "errors" "fmt" "runtime" + "slices" "sync" "time" - "golang.org/x/exp/slices" "tailscale.com/envknob" ) diff --git a/vendor/tailscale.com/proxymap/proxymap.go b/vendor/tailscale.com/proxymap/proxymap.go new file mode 100644 index 0000000000..a1c1bb898f --- /dev/null +++ b/vendor/tailscale.com/proxymap/proxymap.go @@ -0,0 +1,72 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package proxymap contains a mapping table for ephemeral localhost ports used +// by tailscaled on behalf of remote Tailscale IPs for proxied connections. +package proxymap + +import ( + "net/netip" + "sync" + "time" + + "tailscale.com/util/mak" +) + +// Mapper tracks which localhost ip:ports correspond to which remote Tailscale +// IPs for connections proxied by tailscaled. +// +// This is then used (via the WhoIsIPPort method) by localhost applications to +// ask tailscaled (via the LocalAPI WhoIs method) the Tailscale identity that a +// given localhost:port corresponds to. +type Mapper struct { + mu sync.Mutex + m map[netip.AddrPort]netip.Addr +} + +// RegisterIPPortIdentity registers a given node (identified by its +// Tailscale IP) as temporarily having the given IP:port for whois lookups. +// The IP:port is generally a localhost IP and an ephemeral port, used +// while proxying connections to localhost when tailscaled is running +// in netstack mode. 
+func (m *Mapper) RegisterIPPortIdentity(ipport netip.AddrPort, tsIP netip.Addr) { + m.mu.Lock() + defer m.mu.Unlock() + mak.Set(&m.m, ipport, tsIP) +} + +// UnregisterIPPortIdentity removes a temporary IP:port registration +// made previously by RegisterIPPortIdentity. +func (m *Mapper) UnregisterIPPortIdentity(ipport netip.AddrPort) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.m, ipport) +} + +var whoIsSleeps = [...]time.Duration{ + 0, + 10 * time.Millisecond, + 20 * time.Millisecond, + 50 * time.Millisecond, + 100 * time.Millisecond, +} + +// WhoIsIPPort looks up an IP:port in the temporary registrations, +// and returns a matching Tailscale IP, if it exists. +func (m *Mapper) WhoIsIPPort(ipport netip.AddrPort) (tsIP netip.Addr, ok bool) { + // We currently have a registration race, + // https://github.com/tailscale/tailscale/issues/1616, + // so loop a few times for now waiting for the registration + // to appear. + // TODO(bradfitz,namansood): remove this once #1616 is fixed. + for _, d := range whoIsSleeps { + time.Sleep(d) + m.mu.Lock() + tsIP, ok = m.m[ipport] + m.mu.Unlock() + if ok { + return tsIP, true + } + } + return tsIP, false +} diff --git a/vendor/tailscale.com/safesocket/safesocket_darwin.go b/vendor/tailscale.com/safesocket/safesocket_darwin.go index f725708903..36fc7c4382 100644 --- a/vendor/tailscale.com/safesocket/safesocket_darwin.go +++ b/vendor/tailscale.com/safesocket/safesocket_darwin.go @@ -8,12 +8,14 @@ import ( "bytes" "errors" "fmt" + "net" "os" "os/exec" "path/filepath" "strconv" "strings" "sync" + "time" ) func init() { @@ -46,6 +48,17 @@ func localTCPPortAndTokenMacsys() (port int, token string, err error) { if auth == "" { return 0, "", errors.New("empty auth token in sameuserproof file") } + + // The above files exist forever after the first run of + // /Applications/Tailscale.app, so check we can connect to avoid returning a + // port nothing is listening on. Connect to "127.0.0.1" rather than + // "localhost" due to #7851. 
+ conn, err := net.DialTimeout("tcp", "127.0.0.1:"+portStr, time.Second) + if err != nil { + return 0, "", err + } + conn.Close() + return port, auth, nil } diff --git a/vendor/tailscale.com/safesocket/safesocket_plan9.go b/vendor/tailscale.com/safesocket/safesocket_plan9.go new file mode 100644 index 0000000000..4459633948 --- /dev/null +++ b/vendor/tailscale.com/safesocket/safesocket_plan9.go @@ -0,0 +1,124 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build plan9 + +package safesocket + +import ( + "fmt" + "net" + "os" + "syscall" + "time" + + "golang.org/x/sys/plan9" +) + +// Plan 9's devsrv srv(3) is a server registry and +// it is conventionally bound to "/srv" in the default +// namespace. It is "a one level directory for holding +// already open channels to services". Post one end of +// a pipe to "/srv/tailscale.sock" and use the other +// end for communication with a requestor. Plan 9 pipes +// are bidirectional. + +type plan9SrvAddr string + +func (sl plan9SrvAddr) Network() string { + return "/srv" +} + +func (sl plan9SrvAddr) String() string { + return string(sl) +} + +// There is no net.FileListener for Plan 9 at this time +type plan9SrvListener struct { + name string + srvf *os.File + file *os.File +} + +func (sl *plan9SrvListener) Accept() (net.Conn, error) { + // sl.file is the server end of the pipe that's + // connected to /srv/tailscale.sock + return plan9FileConn{name: sl.name, file: sl.file}, nil +} + +func (sl *plan9SrvListener) Close() error { + sl.file.Close() + return sl.srvf.Close() +} + +func (sl *plan9SrvListener) Addr() net.Addr { + return plan9SrvAddr(sl.name) +} + +type plan9FileConn struct { + name string + file *os.File +} + +func (fc plan9FileConn) Read(b []byte) (n int, err error) { + return fc.file.Read(b) +} +func (fc plan9FileConn) Write(b []byte) (n int, err error) { + return fc.file.Write(b) +} +func (fc plan9FileConn) Close() error { + return fc.file.Close() +} +func (fc 
plan9FileConn) LocalAddr() net.Addr { + return plan9SrvAddr(fc.name) +} +func (fc plan9FileConn) RemoteAddr() net.Addr { + return plan9SrvAddr(fc.name) +} +func (fc plan9FileConn) SetDeadline(t time.Time) error { + return syscall.EPLAN9 +} +func (fc plan9FileConn) SetReadDeadline(t time.Time) error { + return syscall.EPLAN9 +} +func (fc plan9FileConn) SetWriteDeadline(t time.Time) error { + return syscall.EPLAN9 +} + +func connect(s *ConnectionStrategy) (net.Conn, error) { + f, err := os.OpenFile(s.path, os.O_RDWR, 0666) + if err != nil { + return nil, err + } + + return plan9FileConn{name: s.path, file: f}, nil +} + +// Create an entry in /srv, open a pipe, write the +// client end to the entry and return the server +// end of the pipe to the caller. When the server +// end of the pipe is closed, /srv name associated +// with it will be removed (controlled by ORCLOSE flag) +func listen(path string) (net.Listener, error) { + const O_RCLOSE = 64 // remove on close; should be in plan9 package + var pip [2]int + + err := plan9.Pipe(pip[:]) + if err != nil { + return nil, err + } + defer plan9.Close(pip[1]) + + srvfd, err := plan9.Create(path, plan9.O_WRONLY|plan9.O_CLOEXEC|O_RCLOSE, 0600) + if err != nil { + return nil, err + } + srv := os.NewFile(uintptr(srvfd), path) + + _, err = fmt.Fprintf(srv, "%d", pip[1]) + if err != nil { + return nil, err + } + + return &plan9SrvListener{name: path, srvf: srv, file: os.NewFile(uintptr(pip[0]), path)}, nil +} diff --git a/vendor/tailscale.com/safesocket/unixsocket.go b/vendor/tailscale.com/safesocket/unixsocket.go index a915927428..0a0dd485c5 100644 --- a/vendor/tailscale.com/safesocket/unixsocket.go +++ b/vendor/tailscale.com/safesocket/unixsocket.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !windows && !js +//go:build !windows && !js && !plan9 package safesocket diff --git a/vendor/tailscale.com/shell.nix b/vendor/tailscale.com/shell.nix index 
8c3dfdb88f..c81b04e096 100644 --- a/vendor/tailscale.com/shell.nix +++ b/vendor/tailscale.com/shell.nix @@ -16,4 +16,4 @@ ) { src = ./.; }).shellNix -# nix-direnv cache busting line: sha256-fgCrmtJs1svFz0Xn7iwLNrbBNlcO6V0yqGPMY0+V1VQ= +# nix-direnv cache busting line: sha256-aVtlDzC+sbEWlUAzPkAryA/+dqSzoAFc02xikh6yhf8= diff --git a/vendor/tailscale.com/syncs/shardedmap.go b/vendor/tailscale.com/syncs/shardedmap.go new file mode 100644 index 0000000000..12edf5bfce --- /dev/null +++ b/vendor/tailscale.com/syncs/shardedmap.go @@ -0,0 +1,138 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "sync" + + "golang.org/x/sys/cpu" +) + +// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined +// K-sharding function. +// +// The zero value is not safe for use; use NewShardedMap. +type ShardedMap[K comparable, V any] struct { + shardFunc func(K) int + shards []mapShard[K, V] +} + +type mapShard[K comparable, V any] struct { + mu sync.Mutex + m map[K]V + _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes +} + +// NewShardedMap returns a new ShardedMap with the given number of shards and +// sharding function. +// +// The shard func must return a integer in the range [0, shards) purely +// deterministically based on the provided K. +func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { + m := &ShardedMap[K, V]{ + shardFunc: shard, + shards: make([]mapShard[K, V], shards), + } + for i := range m.shards { + m.shards[i].m = make(map[K]V) + } + return m +} + +func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { + return &m.shards[m.shardFunc(key)] +} + +// GetOk returns m[key] and whether it was present. 
+func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + value, ok = shard.m[key] + return +} + +// Get returns m[key] or the zero value of V if key is not present. +func (m *ShardedMap[K, V]) Get(key K) (value V) { + value, _ = m.GetOk(key) + return +} + +// Mutate atomically mutates m[k] by calling mutator. +// +// The mutator function is called with the old value (or its zero value) and +// whether it existed in the map and it returns the new value and whether it +// should be set in the map (true) or deleted from the map (false). +// +// It returns the change in size of the map as a result of the mutation, one of +// -1 (delete), 0 (change), or 1 (addition). +func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + oldV, oldOK := shard.m[key] + newV, newOK := mutator(oldV, oldOK) + if newOK { + shard.m[key] = newV + if oldOK { + return 0 + } + return 1 + } + delete(shard.m, key) + if oldOK { + return -1 + } + return 0 +} + +// Set sets m[key] = value. +// +// present in m). +func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + shard.m[key] = value + return len(shard.m) > s0 +} + +// Delete removes key from m. +// +// It reports whether the map size shrunk (that is, whether key was present in +// the map). +func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + delete(shard.m, key) + return len(shard.m) < s0 +} + +// Contains reports whether m contains key. +func (m *ShardedMap[K, V]) Contains(key K) bool { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + _, ok := shard.m[key] + return ok +} + +// Len returns the number of elements in m. 
+// +// It does so by locking shards one at a time, so it's not particularly cheap, +// nor does it give a consistent snapshot of the map. It's mostly intended for +// metrics or testing. +func (m *ShardedMap[K, V]) Len() int { + n := 0 + for i := range m.shards { + shard := &m.shards[i] + shard.mu.Lock() + n += len(shard.m) + shard.mu.Unlock() + } + return n +} diff --git a/vendor/tailscale.com/syncs/syncs.go b/vendor/tailscale.com/syncs/syncs.go index 79acd8654c..abfaba5d88 100644 --- a/vendor/tailscale.com/syncs/syncs.go +++ b/vendor/tailscale.com/syncs/syncs.go @@ -227,6 +227,13 @@ func (m *Map[K, V]) Len() int { return len(m.m) } +// Clear removes all entries from the map. +func (m *Map[K, V]) Clear() { + m.mu.Lock() + defer m.mu.Unlock() + clear(m.m) +} + // WaitGroup is identical to [sync.WaitGroup], // but provides a Go method to start a goroutine. type WaitGroup struct{ sync.WaitGroup } diff --git a/vendor/tailscale.com/tailcfg/c2ntypes.go b/vendor/tailscale.com/tailcfg/c2ntypes.go index 0578299d73..44f3ac70c0 100644 --- a/vendor/tailscale.com/tailcfg/c2ntypes.go +++ b/vendor/tailscale.com/tailcfg/c2ntypes.go @@ -33,3 +33,22 @@ type C2NSSHUsernamesResponse struct { // just a best effort set of hints. Usernames []string } + +// C2NUpdateResponse is the response (from node to control) from the /update +// handler. It tells control the status of its request for the node to update +// its Tailscale installation. +type C2NUpdateResponse struct { + // Err is the error message, if any. + Err string `json:",omitempty"` + + // Enabled indicates whether the user has opted in to updates triggered from + // control. + Enabled bool + + // Supported indicates whether remote updates are supported on this + // OS/platform. + Supported bool + + // Started indicates whether the update has started. 
+ Started bool +} diff --git a/vendor/tailscale.com/tailcfg/derpmap.go b/vendor/tailscale.com/tailcfg/derpmap.go index abc763e47e..d95d26d57d 100644 --- a/vendor/tailscale.com/tailcfg/derpmap.go +++ b/vendor/tailscale.com/tailcfg/derpmap.go @@ -7,6 +7,11 @@ import "sort" // DERPMap describes the set of DERP packet relay servers that are available. type DERPMap struct { + // HomeParams, if non-nil, is a change in home parameters. + // + // The rest of the DEPRMap fields, if zero, means unchanged. + HomeParams *DERPHomeParams `json:",omitempty"` + // Regions is the set of geographic regions running DERP node(s). // // It's keyed by the DERPRegion.RegionID. @@ -16,6 +21,8 @@ type DERPMap struct { // OmitDefaultRegions specifies to not use Tailscale's DERP servers, and only use those // specified in this DERPMap. If there are none set outside of the defaults, this is a noop. + // + // This field is only meaningful if the Regions map is non-nil (indicating a change). OmitDefaultRegions bool `json:"omitDefaultRegions,omitempty"` } @@ -29,6 +36,25 @@ func (m *DERPMap) RegionIDs() []int { return ret } +// DERPHomeParams contains parameters from the server related to selecting a +// DERP home region (sometimes referred to as the "preferred DERP"). +type DERPHomeParams struct { + // RegionScore scales latencies of DERP regions by a given scaling + // factor when determining which region to use as the home + // ("preferred") DERP. Scores in the range (0, 1) will cause this + // region to be proportionally more preferred, and scores in the range + // (1, ∞) will penalize a region. + // + // If a region is not present in this map, it is treated as having a + // score of 1.0. + // + // Scores should not be 0 or negative; such scores will be ignored. + // + // A nil map means no change from the previous value (if any); an empty + // non-nil map can be sent to reset all scores back to 1.0. 
+ RegionScore map[int]float64 `json:",omitempty"` +} + // DERPRegion is a geographic region running DERP relay node(s). // // Client nodes discover which region they're closest to, advertise diff --git a/vendor/tailscale.com/tailcfg/tailcfg.go b/vendor/tailscale.com/tailcfg/tailcfg.go index 4c6f9a39f7..180f588a36 100644 --- a/vendor/tailscale.com/tailcfg/tailcfg.go +++ b/vendor/tailscale.com/tailcfg/tailcfg.go @@ -3,23 +3,26 @@ package tailcfg -//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan --clonefunc +//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile --clonefunc import ( "bytes" - "encoding/hex" + "encoding/json" "errors" "fmt" "net/netip" "reflect" + "slices" "strings" "time" + "golang.org/x/exp/maps" "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/opt" "tailscale.com/types/structs" "tailscale.com/types/tkatype" + "tailscale.com/util/cmpx" "tailscale.com/util/dnsname" ) @@ -100,7 +103,18 @@ type CapabilityVersion int // - 61: 2023-04-18: Client understand SSHAction.SSHRecorderFailureAction // - 62: 2023-05-05: Client can notify control over noise for SSHEventNotificationRequest recording failure events // - 63: 2023-06-08: Client understands SSHAction.AllowRemotePortForwarding. 
-const CurrentCapabilityVersion CapabilityVersion = 63 +// - 64: 2023-07-11: Client understands s/CapabilityTailnetLockAlpha/CapabilityTailnetLock +// - 65: 2023-07-12: Client understands DERPMap.HomeParams + incremental DERPMap updates with params +// - 66: 2023-07-23: UserProfile.Groups added (available via WhoIs) +// - 67: 2023-07-25: Client understands PeerCapMap +// - 68: 2023-08-09: Client has dedicated updateRoutine; MapRequest.Stream true means ignore Hostinfo+Endpoints +// - 69: 2023-08-16: removed Debug.LogHeap* + GoroutineDumpURL; added c2n /debug/logheap +// - 70: 2023-08-16: removed most Debug fields; added NodeAttrDisable*, NodeAttrDebug* instead +// - 71: 2023-08-17: added NodeAttrOneCGNATEnable, NodeAttrOneCGNATDisable +// - 72: 2023-08-23: TS-2023-006 UPnP issue fixed; UPnP can now be used again +// - 73: 2023-09-01: Non-Windows clients expect to receive ClientVersion +// - 74: 2023-09-18: Client understands NodeCapMap +const CurrentCapabilityVersion CapabilityVersion = 74 type StableID string @@ -146,7 +160,6 @@ type User struct { LoginName string `json:"-"` // not stored, filled from Login // TODO REMOVE DisplayName string // if non-empty overrides Login field ProfilePicURL string // if non-empty overrides Login field - Domain string Logins []LoginID Created time.Time } @@ -158,7 +171,6 @@ type Login struct { LoginName string DisplayName string ProfilePicURL string - Domain string } // A UserProfile is display-friendly data for a user. @@ -173,6 +185,27 @@ type UserProfile struct { // Roles exists for legacy reasons, to keep old macOS clients // happy. It JSON marshals as []. Roles emptyStructJSONSlice + + // Groups contains group identifiers for any group that this user is + // a part of and that the coordination server is configured to tell + // your node about. (Thus, it may be empty or incomplete.) + // There's no semantic difference between a nil and an empty list. + // The list is always sorted. 
+ Groups []string `json:",omitempty"` +} + +func (p *UserProfile) Equal(p2 *UserProfile) bool { + if p == nil && p2 == nil { + return true + } + if p == nil || p2 == nil { + return false + } + return p.ID == p2.ID && + p.LoginName == p2.LoginName && + p.DisplayName == p2.DisplayName && + p.ProfilePicURL == p2.ProfilePicURL && + (len(p.Groups) == 0 && len(p2.Groups) == 0 || reflect.DeepEqual(p.Groups, p2.Groups)) } type emptyStructJSONSlice struct{} @@ -185,6 +218,31 @@ func (emptyStructJSONSlice) MarshalJSON() ([]byte, error) { func (emptyStructJSONSlice) UnmarshalJSON([]byte) error { return nil } +// RawMessage is a raw encoded JSON value. It implements Marshaler and +// Unmarshaler and can be used to delay JSON decoding or precompute a JSON +// encoding. +// +// It is like json.RawMessage but is a string instead of a []byte to better +// portray immutable data. +type RawMessage string + +// MarshalJSON returns m as the JSON encoding of m. +func (m RawMessage) MarshalJSON() ([]byte, error) { + if m == "" { + return []byte("null"), nil + } + return []byte(m), nil +} + +// UnmarshalJSON sets *m to a copy of data. +func (m *RawMessage) UnmarshalJSON(data []byte) error { + if m == nil { + return errors.New("RawMessage: UnmarshalJSON on nil pointer") + } + *m = RawMessage(data) + return nil +} + type Node struct { ID NodeID StableID StableNodeID @@ -195,9 +253,9 @@ type Node struct { // e.g. "host.tail-scale.ts.net." Name string - // User is the user who created the node. If ACL tags are in - // use for the node then it doesn't reflect the ACL identity - // that the node is running as. + // User is the user who created the node. If ACL tags are in use for the + // node then it doesn't reflect the ACL identity that the node is running + // as. User UserID // Sharer, if non-zero, is the user who shared this node, if different than User. 
@@ -211,10 +269,20 @@ type Node struct { Addresses []netip.Prefix // IP addresses of this Node directly AllowedIPs []netip.Prefix // range of IP addresses to route to this node Endpoints []string `json:",omitempty"` // IP+port (public via STUN, and local LANs) - DERP string `json:",omitempty"` // DERP-in-IP:port ("127.3.3.40:N") endpoint - Hostinfo HostinfoView - Created time.Time - Cap CapabilityVersion `json:",omitempty"` // if non-zero, the node's capability version; old servers might not send + + // DERP is this node's home DERP region ID integer, but shoved into an + // IP:port string for legacy reasons. The IP address is always "127.3.3.40" + // (a loopback address (127) followed by the digits over the letters DERP on + // a QWERTY keyboard (3.3.40)). The "port number" is the home DERP region ID + // integer. + // + // TODO(bradfitz): simplify this legacy mess; add a new HomeDERPRegionID int + // field behind a new capver bump. + DERP string `json:",omitempty"` // DERP-in-IP:port ("127.3.3.40:N") endpoint + + Hostinfo HostinfoView + Created time.Time + Cap CapabilityVersion `json:",omitempty"` // if non-zero, the node's capability version; old servers might not send // Tags are the list of ACL tags applied to this node. // Tags take the form of `tag:` where value starts @@ -242,8 +310,6 @@ type Node struct { // current node doesn't have permission to know. Online *bool `json:",omitempty"` - KeepAlive bool `json:",omitempty"` // open and keep open a connection to this peer - MachineAuthorized bool `json:",omitempty"` // TODO(crawshaw): replace with MachineStatus // Capabilities are capabilities that the node has. @@ -251,7 +317,21 @@ type Node struct { // such as: // "https://tailscale.com/cap/is-admin" // "https://tailscale.com/cap/file-sharing" - Capabilities []string `json:",omitempty"` + // + // Deprecated: use CapMap instead. + Capabilities []NodeCapability `json:",omitempty"` + + // CapMap is a map of capabilities to their optional argument/data values. 
+ // + // It is valid for a capability to not have any argument/data values; such + // capabilities can be tested for using the HasCap method. These type of + // capabilities are used to indicate that a node has a capability, but there + // is no additional data associated with it. These were previously + // represented by the Capabilities field, but can now be represented by + // CapMap with an empty value. + // + // See NodeCapability for more information on keys. + CapMap NodeCapMap `json:",omitempty"` // UnsignedPeerAPIOnly means that this node is not signed nor subject to TKA // restrictions. However, in exchange for that privilege, it does not get @@ -294,11 +374,41 @@ type Node struct { // not be masqueraded (e.g. in case of --snat-subnet-routes). SelfNodeV4MasqAddrForThisPeer *netip.Addr `json:",omitempty"` + // SelfNodeV6MasqAddrForThisPeer is the IPv6 that this peer knows the current node as. + // It may be empty if the peer knows the current node by its native + // IPv6 address. + // This field is only populated in a MapResponse for peers and not + // for the current node. + // + // If set, it should be used to masquerade traffic originating from the + // current node to this peer. The masquerade address is only relevant + // for this peer and not for other peers. + // + // This only applies to traffic originating from the current node to the + // peer or any of its subnets. Traffic originating from subnet routes will + // not be masqueraded (e.g. in case of --snat-subnet-routes). + SelfNodeV6MasqAddrForThisPeer *netip.Addr `json:",omitempty"` + // IsWireGuardOnly indicates that this is a non-Tailscale WireGuard peer, it // is not expected to speak Disco or DERP, and it must have Endpoints in - // order to be reachable. TODO(#7826): 2023-04-06: only the first parseable - // Endpoint is used, see #7826 for updates. + // order to be reachable. 
IsWireGuardOnly bool `json:",omitempty"` + + // ExitNodeDNSResolvers is the list of DNS servers that should be used when this + // node is marked IsWireGuardOnly and being used as an exit node. + ExitNodeDNSResolvers []*dnstype.Resolver `json:",omitempty"` +} + +// HasCap reports whether the node has the given capability. +// It is safe to call on an invalid NodeView. +func (v NodeView) HasCap(cap NodeCapability) bool { + return v.ж.HasCap(cap) +} + +// HasCap reports whether the node has the given capability. +// It is safe to call on a nil Node. +func (v *Node) HasCap(cap NodeCapability) bool { + return v != nil && (v.CapMap.Contains(cap) || slices.Contains(v.Capabilities, cap)) } // DisplayName returns the user-facing name for a node which should @@ -350,6 +460,20 @@ func (n *Node) IsTagged() bool { return len(n.Tags) > 0 } +// SharerOrUser Sharer if set, else User. +func (n *Node) SharerOrUser() UserID { + return cmpx.Or(n.Sharer, n.User) +} + +// IsTagged reports whether the node has any tags. +func (n NodeView) IsTagged() bool { return n.ж.IsTagged() } + +// DisplayName wraps Node.DisplayName. +func (n NodeView) DisplayName(forOwner bool) string { return n.ж.DisplayName(forOwner) } + +// SharerOrUser wraps Node.SharerOrUser. +func (n NodeView) SharerOrUser() UserID { return n.ж.SharerOrUser() } + // InitDisplayNames computes and populates n's display name // fields: n.ComputedName, n.computedHostIfDifferent, and // n.ComputedNameWithHost. @@ -393,6 +517,10 @@ const ( MachineInvalid // server has explicitly rejected this machine key ) +func (m MachineStatus) AppendText(b []byte) ([]byte, error) { + return append(b, m.String()...), nil +} + func (m MachineStatus) MarshalText() ([]byte, error) { return []byte(m.String()), nil } @@ -531,6 +659,31 @@ type Service struct { // TODO(apenwarr): add "tags" here for each service? } +// Location represents geographical location data about a +// Tailscale host. 
Location is optional and only set if +// explicitly declared by a node. +type Location struct { + Country string `json:",omitempty"` // User friendly country name, with proper capitalization ("Canada") + CountryCode string `json:",omitempty"` // ISO 3166-1 alpha-2 in upper case ("CA") + City string `json:",omitempty"` // User friendly city name, with proper capitalization ("Squamish") + + // CityCode is a short code representing the city in upper case. + // CityCode is used to disambiguate a city from another location + // with the same city name. It uniquely identifies a particular + // geographical location, within the tailnet. + // IATA, ICAO or ISO 3166-2 codes are recommended ("YSE") + CityCode string `json:",omitempty"` + + // Priority determines the order of use of an exit node when a + // location based preference matches more than one exit node, + // the node with the highest priority wins. Nodes of equal + // probability may be selected arbitrarily. + // + // A value of 0 means the exit node does not have a priority + // preference. A negative int is not allowed. + Priority int `json:",omitempty"` +} + // Hostinfo contains a summary of a Tailscale host. // // Because it contains pointers (slices), this type should not be used @@ -585,6 +738,11 @@ type Hostinfo struct { Userspace opt.Bool `json:",omitempty"` // if the client is running in userspace (netstack) mode UserspaceRouter opt.Bool `json:",omitempty"` // if the client's subnet router is running in userspace (netstack) mode + // Location represents geographical location data about a + // Tailscale host. Location is optional and only set if + // explicitly declared by a node. + Location *Location `json:",omitempty"` + // NOTE: any new fields containing pointers in this type // require changes to Hostinfo.Equal. } @@ -647,11 +805,12 @@ type NetInfo struct { // Empty means not checked. PCP opt.Bool - // PreferredDERP is this node's preferred DERP server - // for incoming traffic. 
The node might be be temporarily - // connected to multiple DERP servers (to send to other nodes) - // but PreferredDERP is the instance number that the node - // subscribes to traffic at. + // PreferredDERP is this node's preferred (home) DERP region ID. + // This is where the node expects to be contacted to begin a + // peer-to-peer connection. The node might be be temporarily + // connected to multiple DERP servers (to speak to other nodes + // that are located elsewhere) but PreferredDERP is the region ID + // that the node subscribes to traffic at. // Zero means disconnected or unknown. PreferredDERP int @@ -668,6 +827,14 @@ type NetInfo struct { // the control plane. DERPLatency map[string]float64 `json:",omitempty"` + // FirewallMode encodes both which firewall mode was selected and why. + // It is Linux-specific (at least as of 2023-08-19) and is meant to help + // debug iptables-vs-nftables issues. The string is of the form + // "{nft,ift}-REASON", like "nft-forced" or "ipt-default". Empty means + // either not Linux or a configuration in which the host firewall rules + // are not managed by tailscaled. + FirewallMode string `json:",omitempty"` + // Update BasicallyEqual when adding fields. 
} @@ -675,10 +842,10 @@ func (ni *NetInfo) String() string { if ni == nil { return "NetInfo(nil)" } - return fmt.Sprintf("NetInfo{varies=%v hairpin=%v ipv6=%v ipv6os=%v udp=%v icmpv4=%v derp=#%v portmap=%v link=%q}", + return fmt.Sprintf("NetInfo{varies=%v hairpin=%v ipv6=%v ipv6os=%v udp=%v icmpv4=%v derp=#%v portmap=%v link=%q firewallmode=%q}", ni.MappingVariesByDestIP, ni.HairPinning, ni.WorkingIPv6, ni.OSHasIPv6, ni.WorkingUDP, ni.WorkingICMPv4, - ni.PreferredDERP, ni.portMapSummary(), ni.LinkType) + ni.PreferredDERP, ni.portMapSummary(), ni.LinkType, ni.FirewallMode) } func (ni *NetInfo) portMapSummary() string { @@ -726,7 +893,8 @@ func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool { ni.PMP == ni2.PMP && ni.PCP == ni2.PCP && ni.PreferredDERP == ni2.PreferredDERP && - ni.LinkType == ni2.LinkType + ni.LinkType == ni2.LinkType && + ni.FirewallMode == ni2.FirewallMode } // Equal reports whether h and h2 are equal. @@ -829,6 +997,10 @@ const ( SignatureV2 ) +func (st SignatureType) AppendText(b []byte) ([]byte, error) { + return append(b, st.String()...), nil +} + func (st SignatureType) MarshalText() ([]byte, error) { return []byte(st.String()), nil } @@ -867,6 +1039,16 @@ func (st SignatureType) String() string { } } +// RegisterResponseAuth is the authentication information returned by the server +// in response to a RegisterRequest. +type RegisterResponseAuth struct { + _ structs.Incomparable + // One of Provider/LoginName, Oauth2Token, or AuthKey is set. + Provider, LoginName string + Oauth2Token *Oauth2Token + AuthKey string +} + // RegisterRequest is sent by a client to register the key for a node. // It is encoded to JSON, encrypted with golang.org/x/crypto/nacl/box, // using the local machine key, and sent to: @@ -885,13 +1067,7 @@ type RegisterRequest struct { NodeKey key.NodePublic OldNodeKey key.NodePublic NLKey key.NLPublic - Auth struct { - _ structs.Incomparable - // One of Provider/LoginName, Oauth2Token, or AuthKey is set. 
- Provider, LoginName string - Oauth2Token *Oauth2Token - AuthKey string - } + Auth RegisterResponseAuth // Expiry optionally specifies the requested key expiry. // The server policy may override. // As a special case, if Expiry is in the past and NodeKey is @@ -920,29 +1096,6 @@ type RegisterRequest struct { Signature []byte `json:",omitempty"` // as described by SignatureType } -// Clone makes a deep copy of RegisterRequest. -// The result aliases no memory with the original. -// -// TODO: extend cmd/cloner to generate this method. -func (req *RegisterRequest) Clone() *RegisterRequest { - if req == nil { - return nil - } - res := new(RegisterRequest) - *res = *req - if res.Hostinfo != nil { - res.Hostinfo = res.Hostinfo.Clone() - } - if res.Auth.Oauth2Token != nil { - tok := *res.Auth.Oauth2Token - res.Auth.Oauth2Token = &tok - } - res.DeviceCert = append(res.DeviceCert[:0:0], res.DeviceCert...) - res.Signature = append(res.Signature[:0:0], res.Signature...) - res.NodeKeySignature = append(res.NodeKeySignature[:0:0], res.NodeKeySignature...) - return res -} - // RegisterResponse is returned by the server in response to a RegisterRequest. type RegisterResponse struct { User User @@ -996,7 +1149,9 @@ type Endpoint struct { Type EndpointType } -// MapRequest is sent by a client to start a long-poll network map updates. +// MapRequest is sent by a client to either update the control plane +// about its current state, or to start a long-poll of network map updates. +// // The request includes a copy of the client's current set of WireGuard // endpoints and general host information. // @@ -1012,13 +1167,25 @@ type MapRequest struct { // For current values and history, see the CapabilityVersion type's docs. 
Version CapabilityVersion - Compress string // "zstd" or "" (no compression) - KeepAlive bool // whether server should send keep-alives back to us - NodeKey key.NodePublic - DiscoKey key.DiscoPublic - IncludeIPv6 bool `json:",omitempty"` // include IPv6 endpoints in returned Node Endpoints (for Version 4 clients) - Stream bool // if true, multiple MapResponse objects are returned - Hostinfo *Hostinfo + Compress string // "zstd" or "" (no compression) + KeepAlive bool // whether server should send keep-alives back to us + NodeKey key.NodePublic + DiscoKey key.DiscoPublic + + // Stream is whether the client wants to receive multiple MapResponses over + // the same HTTP connection. + // + // If false, the server will send a single MapResponse and then close the + // connection. + // + // If true and Version >= 68, the server should treat this as a read-only + // request and ignore any Hostinfo or other fields that might be set. + Stream bool + + // Hostinfo is the client's current Hostinfo. Although it is always included + // in the request, the server may choose to ignore it when Stream is true + // and Version >= 68. + Hostinfo *Hostinfo // MapSessionHandle, if non-empty, is a request to reattach to a previous // map session after a previous map session was interrupted for whatever @@ -1040,6 +1207,7 @@ type MapRequest struct { MapSessionSeq int64 `json:",omitempty"` // Endpoints are the client's magicsock UDP ip:port endpoints (IPv4 or IPv6). + // These can be ignored if Stream is true and Version >= 68. Endpoints []string // EndpointTypes are the types of the corresponding endpoints in Endpoints. EndpointTypes []EndpointType `json:",omitempty"` @@ -1049,13 +1217,12 @@ type MapRequest struct { // It is encoded as tka.AUMHash.MarshalText. TKAHead string `json:",omitempty"` - // ReadOnly is whether the client just wants to fetch the - // MapResponse, without updating their Endpoints. 
The - // Endpoints field will be ignored and LastSeen will not be - // updated and peers will not be notified of changes. + // ReadOnly was set when client just wanted to fetch the MapResponse, + // without updating their Endpoints. The intended use was for clients to + // discover the DERP map at start-up before their first real endpoint + // update. // - // The intended use is for clients to discover the DERP map at - // start-up before their first real endpoint update. + // Deprecated: always false as of Version 68. ReadOnly bool `json:",omitempty"` // OmitPeers is whether the client is okay with the Peers list being omitted @@ -1119,7 +1286,110 @@ type CapGrant struct { // Caps are the capabilities the source IP matched by // FilterRule.SrcIPs are granted to the destination IP, // matched by Dsts. - Caps []string `json:",omitempty"` + // Deprecated: use CapMap instead. + Caps []PeerCapability `json:",omitempty"` + + // CapMap is a map of capabilities to their values. + // The key is the capability name, and the value is a list of + // values for that capability. + CapMap PeerCapMap `json:",omitempty"` +} + +// PeerCapability represents a capability granted to a peer by a FilterRule when +// the peer communicates with the node that has this rule. Its meaning is +// application-defined. +// +// It must be a URL like "https://tailscale.com/cap/file-send". +type PeerCapability string + +const ( + // PeerCapabilityFileSharingTarget grants the current node the ability to send + // files to the peer which has this capability. + PeerCapabilityFileSharingTarget PeerCapability = "https://tailscale.com/cap/file-sharing-target" + // PeerCapabilityFileSharingSend grants the ability to receive files from a + // node that's owned by a different user. + PeerCapabilityFileSharingSend PeerCapability = "https://tailscale.com/cap/file-send" + // PeerCapabilityDebugPeer grants the ability for a peer to read this node's + // goroutines, metrics, magicsock internal state, etc. 
+ PeerCapabilityDebugPeer PeerCapability = "https://tailscale.com/cap/debug-peer" + // PeerCapabilityWakeOnLAN grants the ability to send a Wake-On-LAN packet. + PeerCapabilityWakeOnLAN PeerCapability = "https://tailscale.com/cap/wake-on-lan" + // PeerCapabilityIngress grants the ability for a peer to send ingress traffic. + PeerCapabilityIngress PeerCapability = "https://tailscale.com/cap/ingress" +) + +// NodeCapMap is a map of capabilities to their optional values. It is valid for +// a capability to have no values (nil slice); such capabilities can be tested +// for by using the Contains method. +// +// See [NodeCapability] for more information on keys. +type NodeCapMap map[NodeCapability][]RawMessage + +// Equal reports whether c and c2 are equal. +func (c NodeCapMap) Equal(c2 NodeCapMap) bool { + return maps.EqualFunc(c, c2, slices.Equal) +} + +// UnmarshalNodeCapJSON unmarshals each JSON value in cm[cap] as T. +// If cap does not exist in cm, it returns (nil, nil). +// It returns an error if the values cannot be unmarshaled into the provided type. +func UnmarshalNodeCapJSON[T any](cm NodeCapMap, cap NodeCapability) ([]T, error) { + vals, ok := cm[cap] + if !ok { + return nil, nil + } + out := make([]T, 0, len(vals)) + for _, v := range vals { + var t T + if err := json.Unmarshal([]byte(v), &t); err != nil { + return nil, err + } + out = append(out, t) + } + return out, nil +} + +// Contains reports whether c has the capability cap. This is used to test for +// the existence of a capability, especially when the capability has no +// associated argument/data values. +func (c NodeCapMap) Contains(cap NodeCapability) bool { + _, ok := c[cap] + return ok +} + +// PeerCapMap is a map of capabilities to their optional values. It is valid for +// a capability to have no values (nil slice); such capabilities can be tested +// for by using the HasCapability method. 
+// +// The values are opaque to Tailscale, but are passed through from the ACLs to +// the application via the WhoIs API. +type PeerCapMap map[PeerCapability][]RawMessage + +// UnmarshalCapJSON unmarshals each JSON value in cm[cap] as T. +// If cap does not exist in cm, it returns (nil, nil). +// It returns an error if the values cannot be unmarshaled into the provided type. +func UnmarshalCapJSON[T any](cm PeerCapMap, cap PeerCapability) ([]T, error) { + vals, ok := cm[cap] + if !ok { + return nil, nil + } + out := make([]T, 0, len(vals)) + for _, v := range vals { + var t T + if err := json.Unmarshal([]byte(v), &t); err != nil { + return nil, err + } + out = append(out, t) + } + return out, nil +} + +// HasCapability reports whether c has the capability cap. This is used to test +// for the existence of a capability, especially when the capability has no +// associated argument/data values. +func (c PeerCapMap) HasCapability(cap PeerCapability) bool { + _, ok := c[cap] + return ok } // FilterRule represents one rule in a packet filter. @@ -1254,7 +1524,12 @@ type DNSConfig struct { // match. // // Matches are case insensitive. - ExitNodeFilteredSet []string + ExitNodeFilteredSet []string `json:",omitempty"` + + // TempCorpIssue13969 is a temporary (2023-08-16) field for an internal hack day prototype. + // It contains a user inputed URL that should have a list of domains to be blocked. + // See https://github.com/tailscale/corp/issues/13969. + TempCorpIssue13969 string `json:",omitempty"` } // DNSRecord is an extra DNS record to add to MagicDNS. @@ -1377,6 +1652,27 @@ type PingResponse struct { IsLocalIP bool `json:",omitempty"` } +// MapResponse is the response to a MapRequest. It describes the state of the +// local node, the peer nodes, the DNS configuration, the packet filter, and +// more. A MapRequest, depending on its parameters, may result in the control +// plane coordination server sending 0, 1 or a stream of multiple MapResponse +// values. 
+// +// When the client sets MapRequest.Stream, the server sends a stream of +// MapResponses. That long-lived HTTP transaction is called a "map poll". In a +// map poll, the first MapResponse will be complete and subsequent MapResponses +// will be incremental updates with only changed information. +// +// The zero value for all fields means "unchanged". Unfortunately, several +// fields were defined before that convention was established, so they use a +// slice with omitempty, meaning this type can't be used to marshal JSON +// containing non-nil zero-length slices (meaning explicitly now empty). The +// control plane uses a separate type to marshal these fields. This type is +// primarily used for unmarshaling responses so the omitempty annotations are +// mostly useless, except that this type is also used for the integration test's +// fake control server. (It's not necessary to marshal a non-nil zero-length +// slice for the things we've needed to test in the integration tests as of +// 2023-09-09). type MapResponse struct { // MapSessionHandle optionally specifies a unique opaque handle for this // stateful MapResponse session. Servers may choose not to send it, and it's @@ -1477,6 +1773,10 @@ type MapResponse struct { // previously streamed non-nil MapResponse.PacketFilter within // the same HTTP response. A non-nil but empty list always means // no PacketFilter (that is, to block everything). + // + // Note that this package's type, due its use of a slice and omitempty, is + // unable to marshal a zero-length non-nil slice. The control server needs + // to marshal this type using a separate type. See MapResponse docs. PacketFilter []FilterRule `json:",omitempty"` // UserProfiles are the user profiles of nodes in the network. @@ -1484,12 +1784,15 @@ type MapResponse struct { // user profiles only. UserProfiles []UserProfile `json:",omitempty"` - // Health, if non-nil, sets the health state - // of the node from the control plane's perspective. 
- // A nil value means no change from the previous MapResponse. - // A non-nil 0-length slice restores the health to good (no known problems). - // A non-zero length slice are the list of problems that the control place - // sees. + // Health, if non-nil, sets the health state of the node from the control + // plane's perspective. A nil value means no change from the previous + // MapResponse. A non-nil 0-length slice restores the health to good (no + // known problems). A non-zero length slice are the list of problems that + // the control place sees. + // + // Note that this package's type, due its use of a slice and omitempty, is + // unable to marshal a zero-length non-nil slice. The control server needs + // to marshal this type using a separate type. See MapResponse docs. Health []string `json:",omitempty"` // SSHPolicy, if non-nil, updates the SSH policy for how incoming @@ -1586,106 +1889,31 @@ type ControlIPCandidate struct { Priority int `json:",omitempty"` } -// Debug are instructions from the control server to the client -// to adjust debug settings. +// Debug used to be a miscellaneous set of declarative debug config changes and +// imperative debug commands. They've since been mostly migrated to node +// attributes (MapResponse.Node.Capabilities) for the declarative things and c2n +// requests for the imperative things. Not much remains here. Don't add more. type Debug struct { - // LogHeapPprof controls whether the client should log - // its heap pprof data. Each true value sent from the server - // means that client should do one more log. - LogHeapPprof bool `json:",omitempty"` - - // LogHeapURL is the URL to POST its heap pprof to. - // Empty means to not log. - LogHeapURL string `json:",omitempty"` - - // ForceBackgroundSTUN controls whether magicsock should - // always do its background STUN queries (see magicsock's - // periodicReSTUN), regardless of inactivity. 
- ForceBackgroundSTUN bool `json:",omitempty"` - - // SetForceBackgroundSTUN controls whether magicsock should always do its - // background STUN queries (see magicsock's periodicReSTUN), regardless of - // inactivity. - // - // As of capver 37, this field is the preferred field for control to set on - // the wire and ForceBackgroundSTUN is only used within the code as the - // current map session value. But ForceBackgroundSTUN can still be used too. - SetForceBackgroundSTUN opt.Bool `json:",omitempty"` - - // DERPRoute controls whether the DERP reverse path - // optimization (see Issue 150) should be enabled or - // disabled. The environment variable in magicsock is the - // highest priority (if set), then this (if set), then the - // binary default value. - DERPRoute opt.Bool `json:",omitempty"` - - // TrimWGConfig controls whether Tailscale does lazy, on-demand - // wireguard configuration of peers. - TrimWGConfig opt.Bool `json:",omitempty"` - - // DisableSubnetsIfPAC controls whether subnet routers should be - // disabled if WPAD is present on the network. - DisableSubnetsIfPAC opt.Bool `json:",omitempty"` - - // GoroutineDumpURL, if non-empty, requests that the client do - // a one-time dump of its active goroutines to the given URL. - GoroutineDumpURL string `json:",omitempty"` - // SleepSeconds requests that the client sleep for the // provided number of seconds. // The client can (and should) limit the value (such as 5 - // minutes). + // minutes). This exists as a safety measure to slow down + // spinning clients, in case we introduce a bug in the + // state machine. SleepSeconds float64 `json:",omitempty"` - // RandomizeClientPort is whether magicsock should UDP bind to - // :0 to get a random local port, ignoring any configured - // fixed port. - RandomizeClientPort bool `json:",omitempty"` - - // SetRandomizeClientPort is whether magicsock should UDP bind to :0 to get - // a random local port, ignoring any configured fixed port. 
- // - // As of capver 37, this field is the preferred field for control to set on - // the wire and RandomizeClientPort is only used within the code as the - // current map session value. But RandomizeClientPort can still be used too. - SetRandomizeClientPort opt.Bool `json:",omitempty"` - - // OneCGNATRoute controls whether the client should prefer to make one - // big CGNAT /10 route rather than a /32 per peer. - OneCGNATRoute opt.Bool `json:",omitempty"` - - // DisableUPnP is whether the client will attempt to perform a UPnP portmapping. - // By default, we want to enable it to see if it works on more clients. - // - // If UPnP catastrophically fails for people, this should be set to True to kill - // new attempts at UPnP connections. - DisableUPnP opt.Bool `json:",omitempty"` - // DisableLogTail disables the logtail package. Once disabled it can't be // re-enabled for the lifetime of the process. + // + // This is primarily used by Headscale. DisableLogTail bool `json:",omitempty"` - // EnableSilentDisco disables the use of heartBeatTimer in magicsock and attempts to - // handle disco silently. See issue #540 for details. - EnableSilentDisco bool `json:",omitempty"` - // Exit optionally specifies that the client should os.Exit - // with this code. + // with this code. This is a safety measure in case a client is crash + // looping or in an unsafe state and we need to remotely shut it down. Exit *int `json:",omitempty"` } -func appendKey(base []byte, prefix string, k [32]byte) []byte { - ret := append(base, make([]byte, len(prefix)+64)...) 
- buf := ret[len(base):] - copy(buf, prefix) - hex.Encode(buf[len(prefix):], k[:]) - return ret -} - -func keyMarshalText(prefix string, k [32]byte) []byte { - return appendKey(nil, prefix, k) -} - func (id ID) String() string { return fmt.Sprintf("id:%x", int64(id)) } func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) } func (id LoginID) String() string { return fmt.Sprintf("loginid:%x", int64(id)) } @@ -1719,13 +1947,15 @@ func (n *Node) Equal(n2 *Node) bool { n.Created.Equal(n2.Created) && eqTimePtr(n.LastSeen, n2.LastSeen) && n.MachineAuthorized == n2.MachineAuthorized && - eqStrings(n.Capabilities, n2.Capabilities) && + slices.Equal(n.Capabilities, n2.Capabilities) && + n.CapMap.Equal(n2.CapMap) && n.ComputedName == n2.ComputedName && n.computedHostIfDifferent == n2.computedHostIfDifferent && n.ComputedNameWithHost == n2.ComputedNameWithHost && eqStrings(n.Tags, n2.Tags) && n.Expired == n2.Expired && eqPtr(n.SelfNodeV4MasqAddrForThisPeer, n2.SelfNodeV4MasqAddrForThisPeer) && + eqPtr(n.SelfNodeV6MasqAddrForThisPeer, n2.SelfNodeV6MasqAddrForThisPeer) && n.IsWireGuardOnly == n2.IsWireGuardOnly } @@ -1792,89 +2022,121 @@ type Oauth2Token struct { Expiry time.Time `json:"expiry,omitempty"` } -const ( - // These are the capabilities that the self node has as listed in - // MapResponse.Node.Capabilities. - // - // We've since started referring to these as "Node Attributes" ("nodeAttrs" - // in the ACL policy file). +// NodeCapability represents a capability granted to the self node as listed in +// MapResponse.Node.Capabilities. +// +// It must be a URL like "https://tailscale.com/cap/file-sharing", or a +// well-known capability name like "funnel". The latter is only allowed for +// Tailscale-defined capabilities. +// +// Unlike PeerCapability, NodeCapability is not in context of a peer and is +// granted to the node itself. +// +// These are also referred to as "Node Attributes" in the ACL policy file. 
+type NodeCapability string - CapabilityFileSharing = "https://tailscale.com/cap/file-sharing" - CapabilityAdmin = "https://tailscale.com/cap/is-admin" - CapabilitySSH = "https://tailscale.com/cap/ssh" // feature enabled/available - CapabilitySSHRuleIn = "https://tailscale.com/cap/ssh-rule-in" // some SSH rule reach this node - CapabilityDataPlaneAuditLogs = "https://tailscale.com/cap/data-plane-audit-logs" // feature enabled - CapabilityDebug = "https://tailscale.com/cap/debug" // exposes debug endpoints over the PeerAPI +const ( + CapabilityFileSharing NodeCapability = "https://tailscale.com/cap/file-sharing" + CapabilityAdmin NodeCapability = "https://tailscale.com/cap/is-admin" + CapabilitySSH NodeCapability = "https://tailscale.com/cap/ssh" // feature enabled/available + CapabilitySSHRuleIn NodeCapability = "https://tailscale.com/cap/ssh-rule-in" // some SSH rule reach this node + CapabilityDataPlaneAuditLogs NodeCapability = "https://tailscale.com/cap/data-plane-audit-logs" // feature enabled + CapabilityDebug NodeCapability = "https://tailscale.com/cap/debug" // exposes debug endpoints over the PeerAPI + CapabilityHTTPS NodeCapability = "https" // https cert provisioning enabled on tailnet // CapabilityBindToInterfaceByRoute changes how Darwin nodes create // sockets (in the net/netns package). See that package for more // details on the behaviour of this capability. - CapabilityBindToInterfaceByRoute = "https://tailscale.com/cap/bind-to-interface-by-route" + CapabilityBindToInterfaceByRoute NodeCapability = "https://tailscale.com/cap/bind-to-interface-by-route" // CapabilityDebugDisableAlternateDefaultRouteInterface changes how Darwin // nodes get the default interface. There is an optional hook (used by the // macOS and iOS clients) to override the default interface, this capability // disables that and uses the default behavior (of parsing the routing // table). 
- CapabilityDebugDisableAlternateDefaultRouteInterface = "https://tailscale.com/cap/debug-disable-alternate-default-route-interface" + CapabilityDebugDisableAlternateDefaultRouteInterface NodeCapability = "https://tailscale.com/cap/debug-disable-alternate-default-route-interface" // CapabilityDebugDisableBindConnToInterface disables the automatic binding // of connections to the default network interface on Darwin nodes. - CapabilityDebugDisableBindConnToInterface = "https://tailscale.com/cap/debug-disable-bind-conn-to-interface" - - // CapabilityTailnetLockAlpha indicates the node is in the tailnet lock alpha, - // and initialization of tailnet lock may proceed. - // - // TODO(tom): Remove this for 1.35 and later. - CapabilityTailnetLockAlpha = "https://tailscale.com/cap/tailnet-lock-alpha" + CapabilityDebugDisableBindConnToInterface NodeCapability = "https://tailscale.com/cap/debug-disable-bind-conn-to-interface" - // Inter-node capabilities as specified in the MapResponse.PacketFilter[].CapGrants. - - // CapabilityFileSharingTarget grants the current node the ability to send - // files to the peer which has this capability. - CapabilityFileSharingTarget = "https://tailscale.com/cap/file-sharing-target" - // CapabilityFileSharingSend grants the ability to receive files from a - // node that's owned by a different user. - CapabilityFileSharingSend = "https://tailscale.com/cap/file-send" - // CapabilityDebugPeer grants the ability for a peer to read this node's - // goroutines, metrics, magicsock internal state, etc. - CapabilityDebugPeer = "https://tailscale.com/cap/debug-peer" - // CapabilityWakeOnLAN grants the ability to send a Wake-On-LAN packet. - CapabilityWakeOnLAN = "https://tailscale.com/cap/wake-on-lan" - // CapabilityIngress grants the ability for a peer to send ingress traffic. - CapabilityIngress = "https://tailscale.com/cap/ingress" - // CapabilitySSHSessionHaul grants the ability to receive SSH session logs - // from a peer. 
- CapabilitySSHSessionHaul = "https://tailscale.com/cap/ssh-session-haul" + // CapabilityTailnetLock indicates the node may initialize tailnet lock. + CapabilityTailnetLock NodeCapability = "https://tailscale.com/cap/tailnet-lock" // Funnel warning capabilities used for reporting errors to the user. // CapabilityWarnFunnelNoInvite indicates whether Funnel is enabled for the tailnet. - // NOTE: In transition from Alpha to Beta, this capability is being reused as the enablement. - CapabilityWarnFunnelNoInvite = "https://tailscale.com/cap/warn-funnel-no-invite" + // This cap is no longer used 2023-08-09 onwards. + CapabilityWarnFunnelNoInvite NodeCapability = "https://tailscale.com/cap/warn-funnel-no-invite" // CapabilityWarnFunnelNoHTTPS indicates HTTPS has not been enabled for the tailnet. - CapabilityWarnFunnelNoHTTPS = "https://tailscale.com/cap/warn-funnel-no-https" + // This cap is no longer used 2023-08-09 onwards. + CapabilityWarnFunnelNoHTTPS NodeCapability = "https://tailscale.com/cap/warn-funnel-no-https" // Debug logging capabilities // CapabilityDebugTSDNSResolution enables verbose debug logging for DNS // resolution for Tailscale-controlled domains (the control server, log // server, DERP servers, etc.) - CapabilityDebugTSDNSResolution = "https://tailscale.com/cap/debug-ts-dns-resolution" + CapabilityDebugTSDNSResolution NodeCapability = "https://tailscale.com/cap/debug-ts-dns-resolution" // CapabilityFunnelPorts specifies the ports that the Funnel is available on. // The ports are specified as a comma-separated list of port numbers or port // ranges (e.g. "80,443,8080-8090") in the ports query parameter. // e.g. https://tailscale.com/cap/funnel-ports?ports=80,443,8080-8090 - CapabilityFunnelPorts = "https://tailscale.com/cap/funnel-ports" -) + CapabilityFunnelPorts NodeCapability = "https://tailscale.com/cap/funnel-ports" -const ( // NodeAttrFunnel grants the ability for a node to host ingress traffic. 
- NodeAttrFunnel = "funnel" + NodeAttrFunnel NodeCapability = "funnel" // NodeAttrSSHAggregator grants the ability for a node to collect SSH sessions. - NodeAttrSSHAggregator = "ssh-aggregator" + NodeAttrSSHAggregator NodeCapability = "ssh-aggregator" + + // NodeAttrDebugForceBackgroundSTUN forces a node to always do background + // STUN queries regardless of inactivity. + NodeAttrDebugForceBackgroundSTUN NodeCapability = "debug-always-stun" + + // NodeAttrDebugDisableWGTrim disables the lazy WireGuard configuration, + // always giving WireGuard the full netmap, even for idle peers. + NodeAttrDebugDisableWGTrim NodeCapability = "debug-no-wg-trim" + + // NodeAttrDebugDisableDRPO disables the DERP Return Path Optimization. + // See Issue 150. + NodeAttrDebugDisableDRPO NodeCapability = "debug-disable-drpo" + + // NodeAttrDisableSubnetsIfPAC controls whether subnet routers should be + // disabled if WPAD is present on the network. + NodeAttrDisableSubnetsIfPAC NodeCapability = "debug-disable-subnets-if-pac" + + // NodeAttrDisableUPnP makes the client not perform a UPnP portmapping. + // By default, we want to enable it to see if it works on more clients. + // + // If UPnP catastrophically fails for people, this should be set kill + // new attempts at UPnP connections. + NodeAttrDisableUPnP NodeCapability = "debug-disable-upnp" + + // NodeAttrDisableDeltaUpdates makes the client not process updates via the + // delta update mechanism and should instead treat all netmap changes as + // "full" ones as tailscaled did in 1.48.x and earlier. + NodeAttrDisableDeltaUpdates NodeCapability = "disable-delta-updates" + + // NodeAttrRandomizeClientPort makes magicsock UDP bind to + // :0 to get a random local port, ignoring any configured + // fixed port. + NodeAttrRandomizeClientPort NodeCapability = "randomize-client-port" + + // NodeAttrOneCGNATEnable makes the client prefer one big CGNAT /10 route + // rather than a /32 per peer. 
At most one of this or + // NodeAttrOneCGNATDisable may be set; if neither are, it's automatic. + NodeAttrOneCGNATEnable NodeCapability = "one-cgnat?v=true" + + // NodeAttrOneCGNATDisable makes the client prefer a /32 route per peer + // rather than one big /10 CGNAT route. At most one of this or + // NodeAttrOneCGNATEnable may be set; if neither are, it's automatic. + NodeAttrOneCGNATDisable NodeCapability = "one-cgnat?v=false" + + // NodeAttrPeerMTUEnable makes the client do path MTU discovery to its + // peers. If it isn't set, it defaults to the client default. + NodeAttrPeerMTUEnable NodeCapability = "peer-mtu-enable" ) // SetDNSRequest is a request to add a DNS record. @@ -2143,6 +2405,51 @@ type SSHRecordingAttempt struct { FailureMessage string } +// QueryFeatureRequest is a request sent to "/machine/feature/query" +// to get instructions on how to enable a feature, such as Funnel, +// for the node's tailnet. +// +// See QueryFeatureResponse for response structure. +type QueryFeatureRequest struct { + // Feature is the string identifier for a feature. + Feature string `json:",omitempty"` + // NodeKey is the client's current node key. + NodeKey key.NodePublic `json:",omitempty"` +} + +// QueryFeatureResponse is the response to an QueryFeatureRequest. +// See cli.enableFeatureInteractive for usage. +type QueryFeatureResponse struct { + // Complete is true when the feature is already enabled. + Complete bool `json:",omitempty"` + + // Text holds lines to display in the CLI with information + // about the feature and how to enable it. + // + // Lines are separated by newline characters. The final + // newline may be omitted. + Text string `json:",omitempty"` + + // URL is the link for the user to visit to take action on + // enabling the feature. + // + // When empty, there is no action for this user to take. + URL string `json:",omitempty"` + + // ShouldWait specifies whether the CLI should block and + // wait for the user to enable the feature. 
+ // + // If this is true, the enablement from the control server + // is expected to be a quick and uninterrupted process for + // the user, and blocking allows them to immediately start + // using the feature once enabled without rerunning the + // command (e.g. no need to re-run "funnel on"). + // + // The CLI can watch the IPN notification bus for changes in + // required node capabilities to know when to continue. + ShouldWait bool `json:",omitempty"` +} + // OverTLSPublicKeyResponse is the JSON response to /key?v= // over HTTPS (regular TLS) to the Tailscale control plane server, // where the 'v' argument is the client's current capability version @@ -2221,6 +2528,9 @@ type PeerChange struct { // Cap, if non-zero, means that NodeID's capability version has changed. Cap CapabilityVersion `json:",omitempty"` + // CapMap, if non-nil, means that NodeID's capability map has changed. + CapMap NodeCapMap `json:",omitempty"` + // Endpoints, if non-empty, means that NodeID's UDP Endpoints // have changed to these. Endpoints []string `json:",omitempty"` @@ -2247,7 +2557,7 @@ type PeerChange struct { // Capabilities, if non-nil, means that the NodeID's capabilities changed. // It's a pointer to a slice for "omitempty", to allow differentiating // a change to empty from no change. - Capabilities *[]string `json:",omitempty"` + Capabilities *[]NodeCapability `json:",omitempty"` } // DerpMagicIP is a fake WireGuard endpoint IP address that means to @@ -2257,6 +2567,8 @@ type PeerChange struct { // Mnemonic: 3.3.40 are numbers above the keys D, E, R, P. const DerpMagicIP = "127.3.3.40" +var DerpMagicIPAddr = netip.MustParseAddr(DerpMagicIP) + // EarlyNoise is the early payload that's sent over Noise but before the HTTP/2 // handshake when connecting to the coordination server. 
// diff --git a/vendor/tailscale.com/tailcfg/tailcfg_clone.go b/vendor/tailscale.com/tailcfg/tailcfg_clone.go index 9d72124b4b..6a2292149a 100644 --- a/vendor/tailscale.com/tailcfg/tailcfg_clone.go +++ b/vendor/tailscale.com/tailcfg/tailcfg_clone.go @@ -6,12 +6,14 @@ package tailcfg import ( + "maps" "net/netip" "time" "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/opt" + "tailscale.com/types/ptr" "tailscale.com/types/structs" "tailscale.com/types/tkatype" ) @@ -34,7 +36,6 @@ var _UserCloneNeedsRegeneration = User(struct { LoginName string DisplayName string ProfilePicURL string - Domain string Logins []LoginID Created time.Time }{}) @@ -55,17 +56,29 @@ func (src *Node) Clone() *Node { dst.Tags = append(src.Tags[:0:0], src.Tags...) dst.PrimaryRoutes = append(src.PrimaryRoutes[:0:0], src.PrimaryRoutes...) if dst.LastSeen != nil { - dst.LastSeen = new(time.Time) - *dst.LastSeen = *src.LastSeen + dst.LastSeen = ptr.To(*src.LastSeen) } if dst.Online != nil { - dst.Online = new(bool) - *dst.Online = *src.Online + dst.Online = ptr.To(*src.Online) } dst.Capabilities = append(src.Capabilities[:0:0], src.Capabilities...) + if dst.CapMap != nil { + dst.CapMap = map[NodeCapability][]RawMessage{} + for k := range src.CapMap { + dst.CapMap[k] = append([]RawMessage{}, src.CapMap[k]...) 
+ } + } if dst.SelfNodeV4MasqAddrForThisPeer != nil { - dst.SelfNodeV4MasqAddrForThisPeer = new(netip.Addr) - *dst.SelfNodeV4MasqAddrForThisPeer = *src.SelfNodeV4MasqAddrForThisPeer + dst.SelfNodeV4MasqAddrForThisPeer = ptr.To(*src.SelfNodeV4MasqAddrForThisPeer) + } + if dst.SelfNodeV6MasqAddrForThisPeer != nil { + dst.SelfNodeV6MasqAddrForThisPeer = ptr.To(*src.SelfNodeV6MasqAddrForThisPeer) + } + if src.ExitNodeDNSResolvers != nil { + dst.ExitNodeDNSResolvers = make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers)) + for i := range dst.ExitNodeDNSResolvers { + dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() + } } return dst } @@ -93,9 +106,9 @@ var _NodeCloneNeedsRegeneration = Node(struct { PrimaryRoutes []netip.Prefix LastSeen *time.Time Online *bool - KeepAlive bool MachineAuthorized bool - Capabilities []string + Capabilities []NodeCapability + CapMap NodeCapMap UnsignedPeerAPIOnly bool ComputedName string computedHostIfDifferent string @@ -103,7 +116,9 @@ var _NodeCloneNeedsRegeneration = Node(struct { DataPlaneAuditLogID string Expired bool SelfNodeV4MasqAddrForThisPeer *netip.Addr + SelfNodeV6MasqAddrForThisPeer *netip.Addr IsWireGuardOnly bool + ExitNodeDNSResolvers []*dnstype.Resolver }{}) // Clone makes a deep copy of Hostinfo. @@ -119,6 +134,9 @@ func (src *Hostinfo) Clone() *Hostinfo { dst.Services = append(src.Services[:0:0], src.Services...) dst.NetInfo = src.NetInfo.Clone() dst.SSH_HostKeys = append(src.SSH_HostKeys[:0:0], src.SSH_HostKeys...) + if dst.Location != nil { + dst.Location = ptr.To(*src.Location) + } return dst } @@ -157,6 +175,7 @@ var _HostinfoCloneNeedsRegeneration = Hostinfo(struct { Cloud string Userspace opt.Bool UserspaceRouter opt.Bool + Location *Location }{}) // Clone makes a deep copy of NetInfo. 
@@ -167,12 +186,7 @@ func (src *NetInfo) Clone() *NetInfo { } dst := new(NetInfo) *dst = *src - if dst.DERPLatency != nil { - dst.DERPLatency = map[string]float64{} - for k, v := range src.DERPLatency { - dst.DERPLatency[k] = v - } - } + dst.DERPLatency = maps.Clone(src.DERPLatency) return dst } @@ -191,6 +205,7 @@ var _NetInfoCloneNeedsRegeneration = NetInfo(struct { PreferredDERP int LinkType string DERPLatency map[string]float64 + FirewallMode string }{}) // Clone makes a deep copy of Login. @@ -212,7 +227,6 @@ var _LoginCloneNeedsRegeneration = Login(struct { LoginName string DisplayName string ProfilePicURL string - Domain string }{}) // Clone makes a deep copy of DNSConfig. @@ -223,9 +237,11 @@ func (src *DNSConfig) Clone() *DNSConfig { } dst := new(DNSConfig) *dst = *src - dst.Resolvers = make([]*dnstype.Resolver, len(src.Resolvers)) - for i := range dst.Resolvers { - dst.Resolvers[i] = src.Resolvers[i].Clone() + if src.Resolvers != nil { + dst.Resolvers = make([]*dnstype.Resolver, len(src.Resolvers)) + for i := range dst.Resolvers { + dst.Resolvers[i] = src.Resolvers[i].Clone() + } } if dst.Routes != nil { dst.Routes = map[string][]*dnstype.Resolver{} @@ -233,9 +249,11 @@ func (src *DNSConfig) Clone() *DNSConfig { dst.Routes[k] = append([]*dnstype.Resolver{}, src.Routes[k]...) } } - dst.FallbackResolvers = make([]*dnstype.Resolver, len(src.FallbackResolvers)) - for i := range dst.FallbackResolvers { - dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() + if src.FallbackResolvers != nil { + dst.FallbackResolvers = make([]*dnstype.Resolver, len(src.FallbackResolvers)) + for i := range dst.FallbackResolvers { + dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() + } } dst.Domains = append(src.Domains[:0:0], src.Domains...) dst.Nameservers = append(src.Nameservers[:0:0], src.Nameservers...) 
@@ -256,6 +274,7 @@ var _DNSConfigCloneNeedsRegeneration = DNSConfig(struct { CertDomains []string ExtraRecords []DNSRecord ExitNodeFilteredSet []string + TempCorpIssue13969 string }{}) // Clone makes a deep copy of RegisterResponse. @@ -282,6 +301,84 @@ var _RegisterResponseCloneNeedsRegeneration = RegisterResponse(struct { Error string }{}) +// Clone makes a deep copy of RegisterResponseAuth. +// The result aliases no memory with the original. +func (src *RegisterResponseAuth) Clone() *RegisterResponseAuth { + if src == nil { + return nil + } + dst := new(RegisterResponseAuth) + *dst = *src + if dst.Oauth2Token != nil { + dst.Oauth2Token = ptr.To(*src.Oauth2Token) + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _RegisterResponseAuthCloneNeedsRegeneration = RegisterResponseAuth(struct { + _ structs.Incomparable + Provider string + LoginName string + Oauth2Token *Oauth2Token + AuthKey string +}{}) + +// Clone makes a deep copy of RegisterRequest. +// The result aliases no memory with the original. +func (src *RegisterRequest) Clone() *RegisterRequest { + if src == nil { + return nil + } + dst := new(RegisterRequest) + *dst = *src + dst.Auth = *src.Auth.Clone() + dst.Hostinfo = src.Hostinfo.Clone() + dst.NodeKeySignature = append(src.NodeKeySignature[:0:0], src.NodeKeySignature...) + if dst.Timestamp != nil { + dst.Timestamp = ptr.To(*src.Timestamp) + } + dst.DeviceCert = append(src.DeviceCert[:0:0], src.DeviceCert...) + dst.Signature = append(src.Signature[:0:0], src.Signature...) + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. 
+var _RegisterRequestCloneNeedsRegeneration = RegisterRequest(struct { + _ structs.Incomparable + Version CapabilityVersion + NodeKey key.NodePublic + OldNodeKey key.NodePublic + NLKey key.NLPublic + Auth RegisterResponseAuth + Expiry time.Time + Followup string + Hostinfo *Hostinfo + Ephemeral bool + NodeKeySignature tkatype.MarshaledSignature + SignatureType SignatureType + Timestamp *time.Time + DeviceCert []byte + Signature []byte +}{}) + +// Clone makes a deep copy of DERPHomeParams. +// The result aliases no memory with the original. +func (src *DERPHomeParams) Clone() *DERPHomeParams { + if src == nil { + return nil + } + dst := new(DERPHomeParams) + *dst = *src + dst.RegionScore = maps.Clone(src.RegionScore) + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _DERPHomeParamsCloneNeedsRegeneration = DERPHomeParams(struct { + RegionScore map[int]float64 +}{}) + // Clone makes a deep copy of DERPRegion. // The result aliases no memory with the original. func (src *DERPRegion) Clone() *DERPRegion { @@ -290,9 +387,11 @@ func (src *DERPRegion) Clone() *DERPRegion { } dst := new(DERPRegion) *dst = *src - dst.Nodes = make([]*DERPNode, len(src.Nodes)) - for i := range dst.Nodes { - dst.Nodes[i] = src.Nodes[i].Clone() + if src.Nodes != nil { + dst.Nodes = make([]*DERPNode, len(src.Nodes)) + for i := range dst.Nodes { + dst.Nodes[i] = src.Nodes[i].Clone() + } } return dst } @@ -314,6 +413,7 @@ func (src *DERPMap) Clone() *DERPMap { } dst := new(DERPMap) *dst = *src + dst.HomeParams = src.HomeParams.Clone() if dst.Regions != nil { dst.Regions = map[int]*DERPRegion{} for k, v := range src.Regions { @@ -325,6 +425,7 @@ func (src *DERPMap) Clone() *DERPMap { // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
var _DERPMapCloneNeedsRegeneration = DERPMap(struct { + HomeParams *DERPHomeParams Regions map[int]*DERPRegion OmitDefaultRegions bool }{}) @@ -365,19 +466,15 @@ func (src *SSHRule) Clone() *SSHRule { dst := new(SSHRule) *dst = *src if dst.RuleExpires != nil { - dst.RuleExpires = new(time.Time) - *dst.RuleExpires = *src.RuleExpires - } - dst.Principals = make([]*SSHPrincipal, len(src.Principals)) - for i := range dst.Principals { - dst.Principals[i] = src.Principals[i].Clone() + dst.RuleExpires = ptr.To(*src.RuleExpires) } - if dst.SSHUsers != nil { - dst.SSHUsers = map[string]string{} - for k, v := range src.SSHUsers { - dst.SSHUsers[k] = v + if src.Principals != nil { + dst.Principals = make([]*SSHPrincipal, len(src.Principals)) + for i := range dst.Principals { + dst.Principals[i] = src.Principals[i].Clone() } } + dst.SSHUsers = maps.Clone(src.SSHUsers) dst.Action = src.Action.Clone() return dst } @@ -400,8 +497,7 @@ func (src *SSHAction) Clone() *SSHAction { *dst = *src dst.Recorders = append(src.Recorders[:0:0], src.Recorders...) if dst.OnRecordingFailure != nil { - dst.OnRecordingFailure = new(SSHRecorderFailureAction) - *dst.OnRecordingFailure = *src.OnRecordingFailure + dst.OnRecordingFailure = ptr.To(*src.OnRecordingFailure) } return dst } @@ -458,9 +554,51 @@ var _ControlDialPlanCloneNeedsRegeneration = ControlDialPlan(struct { Candidates []ControlIPCandidate }{}) +// Clone makes a deep copy of Location. +// The result aliases no memory with the original. +func (src *Location) Clone() *Location { + if src == nil { + return nil + } + dst := new(Location) + *dst = *src + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _LocationCloneNeedsRegeneration = Location(struct { + Country string + CountryCode string + City string + CityCode string + Priority int +}{}) + +// Clone makes a deep copy of UserProfile. +// The result aliases no memory with the original. 
+func (src *UserProfile) Clone() *UserProfile { + if src == nil { + return nil + } + dst := new(UserProfile) + *dst = *src + dst.Groups = append(src.Groups[:0:0], src.Groups...) + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _UserProfileCloneNeedsRegeneration = UserProfile(struct { + ID UserID + LoginName string + DisplayName string + ProfilePicURL string + Roles emptyStructJSONSlice + Groups []string +}{}) + // Clone duplicates src into dst and reports whether it succeeded. // To succeed, must be of types <*T, *T> or <*T, **T>, -// where T is one of User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan. +// where T is one of User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile. 
func Clone(dst, src any) bool { switch src := src.(type) { case *User: @@ -526,6 +664,33 @@ func Clone(dst, src any) bool { *dst = src.Clone() return true } + case *RegisterResponseAuth: + switch dst := dst.(type) { + case *RegisterResponseAuth: + *dst = *src.Clone() + return true + case **RegisterResponseAuth: + *dst = src.Clone() + return true + } + case *RegisterRequest: + switch dst := dst.(type) { + case *RegisterRequest: + *dst = *src.Clone() + return true + case **RegisterRequest: + *dst = src.Clone() + return true + } + case *DERPHomeParams: + switch dst := dst.(type) { + case *DERPHomeParams: + *dst = *src.Clone() + return true + case **DERPHomeParams: + *dst = src.Clone() + return true + } case *DERPRegion: switch dst := dst.(type) { case *DERPRegion: @@ -589,6 +754,24 @@ func Clone(dst, src any) bool { *dst = src.Clone() return true } + case *Location: + switch dst := dst.(type) { + case *Location: + *dst = *src.Clone() + return true + case **Location: + *dst = src.Clone() + return true + } + case *UserProfile: + switch dst := dst.(type) { + case *UserProfile: + *dst = *src.Clone() + return true + case **UserProfile: + *dst = src.Clone() + return true + } } return false } diff --git a/vendor/tailscale.com/tailcfg/tailcfg_view.go b/vendor/tailscale.com/tailcfg/tailcfg_view.go index 9c195da1cb..4a51f03f70 100644 --- a/vendor/tailscale.com/tailcfg/tailcfg_view.go +++ b/vendor/tailscale.com/tailcfg/tailcfg_view.go @@ -11,7 +11,6 @@ import ( "net/netip" "time" - "go4.org/mem" "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/opt" @@ -20,7 +19,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan +//go:generate go run tailscale.com/cmd/cloner -clonefunc=true 
-type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile // View returns a readonly view of User. func (p *User) View() UserView { @@ -71,7 +70,6 @@ func (v UserView) ID() UserID { return v.ж.ID } func (v UserView) LoginName() string { return v.ж.LoginName } func (v UserView) DisplayName() string { return v.ж.DisplayName } func (v UserView) ProfilePicURL() string { return v.ж.ProfilePicURL } -func (v UserView) Domain() string { return v.ж.Domain } func (v UserView) Logins() views.Slice[LoginID] { return views.SliceOf(v.ж.Logins) } func (v UserView) Created() time.Time { return v.ж.Created } @@ -81,7 +79,6 @@ var _UserViewNeedsRegeneration = User(struct { LoginName string DisplayName string ProfilePicURL string - Domain string Logins []LoginID Created time.Time }{}) @@ -131,27 +128,27 @@ func (v *NodeView) UnmarshalJSON(b []byte) error { return nil } -func (v NodeView) ID() NodeID { return v.ж.ID } -func (v NodeView) StableID() StableNodeID { return v.ж.StableID } -func (v NodeView) Name() string { return v.ж.Name } -func (v NodeView) User() UserID { return v.ж.User } -func (v NodeView) Sharer() UserID { return v.ж.Sharer } -func (v NodeView) Key() key.NodePublic { return v.ж.Key } -func (v NodeView) KeyExpiry() time.Time { return v.ж.KeyExpiry } -func (v NodeView) KeySignature() mem.RO { return mem.B(v.ж.KeySignature) } -func (v NodeView) Machine() key.MachinePublic { return v.ж.Machine } -func (v NodeView) DiscoKey() key.DiscoPublic { return v.ж.DiscoKey } -func (v NodeView) Addresses() views.IPPrefixSlice { return views.IPPrefixSliceOf(v.ж.Addresses) } -func (v NodeView) AllowedIPs() views.IPPrefixSlice { return views.IPPrefixSliceOf(v.ж.AllowedIPs) } -func (v NodeView) Endpoints() views.Slice[string] { return views.SliceOf(v.ж.Endpoints) } -func (v NodeView) DERP() string { return v.ж.DERP } -func (v 
NodeView) Hostinfo() HostinfoView { return v.ж.Hostinfo } -func (v NodeView) Created() time.Time { return v.ж.Created } -func (v NodeView) Cap() CapabilityVersion { return v.ж.Cap } -func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } -func (v NodeView) PrimaryRoutes() views.IPPrefixSlice { - return views.IPPrefixSliceOf(v.ж.PrimaryRoutes) -} +func (v NodeView) ID() NodeID { return v.ж.ID } +func (v NodeView) StableID() StableNodeID { return v.ж.StableID } +func (v NodeView) Name() string { return v.ж.Name } +func (v NodeView) User() UserID { return v.ж.User } +func (v NodeView) Sharer() UserID { return v.ж.Sharer } +func (v NodeView) Key() key.NodePublic { return v.ж.Key } +func (v NodeView) KeyExpiry() time.Time { return v.ж.KeyExpiry } +func (v NodeView) KeySignature() views.ByteSlice[tkatype.MarshaledSignature] { + return views.ByteSliceOf(v.ж.KeySignature) +} +func (v NodeView) Machine() key.MachinePublic { return v.ж.Machine } +func (v NodeView) DiscoKey() key.DiscoPublic { return v.ж.DiscoKey } +func (v NodeView) Addresses() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.Addresses) } +func (v NodeView) AllowedIPs() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.AllowedIPs) } +func (v NodeView) Endpoints() views.Slice[string] { return views.SliceOf(v.ж.Endpoints) } +func (v NodeView) DERP() string { return v.ж.DERP } +func (v NodeView) Hostinfo() HostinfoView { return v.ж.Hostinfo } +func (v NodeView) Created() time.Time { return v.ж.Created } +func (v NodeView) Cap() CapabilityVersion { return v.ж.Cap } +func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } +func (v NodeView) PrimaryRoutes() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.PrimaryRoutes) } func (v NodeView) LastSeen() *time.Time { if v.ж.LastSeen == nil { return nil @@ -168,14 +165,19 @@ func (v NodeView) Online() *bool { return &x } -func (v NodeView) KeepAlive() bool { return v.ж.KeepAlive } -func (v NodeView) 
MachineAuthorized() bool { return v.ж.MachineAuthorized } -func (v NodeView) Capabilities() views.Slice[string] { return views.SliceOf(v.ж.Capabilities) } -func (v NodeView) UnsignedPeerAPIOnly() bool { return v.ж.UnsignedPeerAPIOnly } -func (v NodeView) ComputedName() string { return v.ж.ComputedName } -func (v NodeView) ComputedNameWithHost() string { return v.ж.ComputedNameWithHost } -func (v NodeView) DataPlaneAuditLogID() string { return v.ж.DataPlaneAuditLogID } -func (v NodeView) Expired() bool { return v.ж.Expired } +func (v NodeView) MachineAuthorized() bool { return v.ж.MachineAuthorized } +func (v NodeView) Capabilities() views.Slice[NodeCapability] { return views.SliceOf(v.ж.Capabilities) } + +func (v NodeView) CapMap() views.MapFn[NodeCapability, []RawMessage, views.Slice[RawMessage]] { + return views.MapFnOf(v.ж.CapMap, func(t []RawMessage) views.Slice[RawMessage] { + return views.SliceOf(t) + }) +} +func (v NodeView) UnsignedPeerAPIOnly() bool { return v.ж.UnsignedPeerAPIOnly } +func (v NodeView) ComputedName() string { return v.ж.ComputedName } +func (v NodeView) ComputedNameWithHost() string { return v.ж.ComputedNameWithHost } +func (v NodeView) DataPlaneAuditLogID() string { return v.ж.DataPlaneAuditLogID } +func (v NodeView) Expired() bool { return v.ж.Expired } func (v NodeView) SelfNodeV4MasqAddrForThisPeer() *netip.Addr { if v.ж.SelfNodeV4MasqAddrForThisPeer == nil { return nil @@ -184,7 +186,18 @@ func (v NodeView) SelfNodeV4MasqAddrForThisPeer() *netip.Addr { return &x } -func (v NodeView) IsWireGuardOnly() bool { return v.ж.IsWireGuardOnly } +func (v NodeView) SelfNodeV6MasqAddrForThisPeer() *netip.Addr { + if v.ж.SelfNodeV6MasqAddrForThisPeer == nil { + return nil + } + x := *v.ж.SelfNodeV6MasqAddrForThisPeer + return &x +} + +func (v NodeView) IsWireGuardOnly() bool { return v.ж.IsWireGuardOnly } +func (v NodeView) ExitNodeDNSResolvers() views.SliceView[*dnstype.Resolver, dnstype.ResolverView] { + return 
views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](v.ж.ExitNodeDNSResolvers) +} func (v NodeView) Equal(v2 NodeView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -210,9 +223,9 @@ var _NodeViewNeedsRegeneration = Node(struct { PrimaryRoutes []netip.Prefix LastSeen *time.Time Online *bool - KeepAlive bool MachineAuthorized bool - Capabilities []string + Capabilities []NodeCapability + CapMap NodeCapMap UnsignedPeerAPIOnly bool ComputedName string computedHostIfDifferent string @@ -220,7 +233,9 @@ var _NodeViewNeedsRegeneration = Node(struct { DataPlaneAuditLogID string Expired bool SelfNodeV4MasqAddrForThisPeer *netip.Addr + SelfNodeV6MasqAddrForThisPeer *netip.Addr IsWireGuardOnly bool + ExitNodeDNSResolvers []*dnstype.Resolver }{}) // View returns a readonly view of Hostinfo. @@ -268,42 +283,48 @@ func (v *HostinfoView) UnmarshalJSON(b []byte) error { return nil } -func (v HostinfoView) IPNVersion() string { return v.ж.IPNVersion } -func (v HostinfoView) FrontendLogID() string { return v.ж.FrontendLogID } -func (v HostinfoView) BackendLogID() string { return v.ж.BackendLogID } -func (v HostinfoView) OS() string { return v.ж.OS } -func (v HostinfoView) OSVersion() string { return v.ж.OSVersion } -func (v HostinfoView) Container() opt.Bool { return v.ж.Container } -func (v HostinfoView) Env() string { return v.ж.Env } -func (v HostinfoView) Distro() string { return v.ж.Distro } -func (v HostinfoView) DistroVersion() string { return v.ж.DistroVersion } -func (v HostinfoView) DistroCodeName() string { return v.ж.DistroCodeName } -func (v HostinfoView) App() string { return v.ж.App } -func (v HostinfoView) Desktop() opt.Bool { return v.ж.Desktop } -func (v HostinfoView) Package() string { return v.ж.Package } -func (v HostinfoView) DeviceModel() string { return v.ж.DeviceModel } -func (v HostinfoView) PushDeviceToken() string { return v.ж.PushDeviceToken } -func 
(v HostinfoView) Hostname() string { return v.ж.Hostname } -func (v HostinfoView) ShieldsUp() bool { return v.ж.ShieldsUp } -func (v HostinfoView) ShareeNode() bool { return v.ж.ShareeNode } -func (v HostinfoView) NoLogsNoSupport() bool { return v.ж.NoLogsNoSupport } -func (v HostinfoView) WireIngress() bool { return v.ж.WireIngress } -func (v HostinfoView) AllowsUpdate() bool { return v.ж.AllowsUpdate } -func (v HostinfoView) Machine() string { return v.ж.Machine } -func (v HostinfoView) GoArch() string { return v.ж.GoArch } -func (v HostinfoView) GoArchVar() string { return v.ж.GoArchVar } -func (v HostinfoView) GoVersion() string { return v.ж.GoVersion } -func (v HostinfoView) RoutableIPs() views.IPPrefixSlice { - return views.IPPrefixSliceOf(v.ж.RoutableIPs) -} -func (v HostinfoView) RequestTags() views.Slice[string] { return views.SliceOf(v.ж.RequestTags) } -func (v HostinfoView) Services() views.Slice[Service] { return views.SliceOf(v.ж.Services) } -func (v HostinfoView) NetInfo() NetInfoView { return v.ж.NetInfo.View() } -func (v HostinfoView) SSH_HostKeys() views.Slice[string] { return views.SliceOf(v.ж.SSH_HostKeys) } -func (v HostinfoView) Cloud() string { return v.ж.Cloud } -func (v HostinfoView) Userspace() opt.Bool { return v.ж.Userspace } -func (v HostinfoView) UserspaceRouter() opt.Bool { return v.ж.UserspaceRouter } -func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) } +func (v HostinfoView) IPNVersion() string { return v.ж.IPNVersion } +func (v HostinfoView) FrontendLogID() string { return v.ж.FrontendLogID } +func (v HostinfoView) BackendLogID() string { return v.ж.BackendLogID } +func (v HostinfoView) OS() string { return v.ж.OS } +func (v HostinfoView) OSVersion() string { return v.ж.OSVersion } +func (v HostinfoView) Container() opt.Bool { return v.ж.Container } +func (v HostinfoView) Env() string { return v.ж.Env } +func (v HostinfoView) Distro() string { return v.ж.Distro } +func (v HostinfoView) DistroVersion() string 
{ return v.ж.DistroVersion } +func (v HostinfoView) DistroCodeName() string { return v.ж.DistroCodeName } +func (v HostinfoView) App() string { return v.ж.App } +func (v HostinfoView) Desktop() opt.Bool { return v.ж.Desktop } +func (v HostinfoView) Package() string { return v.ж.Package } +func (v HostinfoView) DeviceModel() string { return v.ж.DeviceModel } +func (v HostinfoView) PushDeviceToken() string { return v.ж.PushDeviceToken } +func (v HostinfoView) Hostname() string { return v.ж.Hostname } +func (v HostinfoView) ShieldsUp() bool { return v.ж.ShieldsUp } +func (v HostinfoView) ShareeNode() bool { return v.ж.ShareeNode } +func (v HostinfoView) NoLogsNoSupport() bool { return v.ж.NoLogsNoSupport } +func (v HostinfoView) WireIngress() bool { return v.ж.WireIngress } +func (v HostinfoView) AllowsUpdate() bool { return v.ж.AllowsUpdate } +func (v HostinfoView) Machine() string { return v.ж.Machine } +func (v HostinfoView) GoArch() string { return v.ж.GoArch } +func (v HostinfoView) GoArchVar() string { return v.ж.GoArchVar } +func (v HostinfoView) GoVersion() string { return v.ж.GoVersion } +func (v HostinfoView) RoutableIPs() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.RoutableIPs) } +func (v HostinfoView) RequestTags() views.Slice[string] { return views.SliceOf(v.ж.RequestTags) } +func (v HostinfoView) Services() views.Slice[Service] { return views.SliceOf(v.ж.Services) } +func (v HostinfoView) NetInfo() NetInfoView { return v.ж.NetInfo.View() } +func (v HostinfoView) SSH_HostKeys() views.Slice[string] { return views.SliceOf(v.ж.SSH_HostKeys) } +func (v HostinfoView) Cloud() string { return v.ж.Cloud } +func (v HostinfoView) Userspace() opt.Bool { return v.ж.Userspace } +func (v HostinfoView) UserspaceRouter() opt.Bool { return v.ж.UserspaceRouter } +func (v HostinfoView) Location() *Location { + if v.ж.Location == nil { + return nil + } + x := *v.ж.Location + return &x +} + +func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) } 
// A compilation failure here means this code must be regenerated, with the command at the top of this file. var _HostinfoViewNeedsRegeneration = Hostinfo(struct { @@ -340,6 +361,7 @@ var _HostinfoViewNeedsRegeneration = Hostinfo(struct { Cloud string Userspace opt.Bool UserspaceRouter opt.Bool + Location *Location }{}) // View returns a readonly view of NetInfo. @@ -401,6 +423,7 @@ func (v NetInfoView) PreferredDERP() int { return v.ж.PreferredDER func (v NetInfoView) LinkType() string { return v.ж.LinkType } func (v NetInfoView) DERPLatency() views.Map[string, float64] { return views.MapOf(v.ж.DERPLatency) } +func (v NetInfoView) FirewallMode() string { return v.ж.FirewallMode } func (v NetInfoView) String() string { return v.ж.String() } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -418,6 +441,7 @@ var _NetInfoViewNeedsRegeneration = NetInfo(struct { PreferredDERP int LinkType string DERPLatency map[string]float64 + FirewallMode string }{}) // View returns a readonly view of Login. @@ -470,7 +494,6 @@ func (v LoginView) Provider() string { return v.ж.Provider } func (v LoginView) LoginName() string { return v.ж.LoginName } func (v LoginView) DisplayName() string { return v.ж.DisplayName } func (v LoginView) ProfilePicURL() string { return v.ж.ProfilePicURL } -func (v LoginView) Domain() string { return v.ж.Domain } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _LoginViewNeedsRegeneration = Login(struct { @@ -480,7 +503,6 @@ var _LoginViewNeedsRegeneration = Login(struct { LoginName string DisplayName string ProfilePicURL string - Domain string }{}) // View returns a readonly view of DNSConfig. 
@@ -548,6 +570,7 @@ func (v DNSConfigView) ExtraRecords() views.Slice[DNSRecord] { return views.Slic func (v DNSConfigView) ExitNodeFilteredSet() views.Slice[string] { return views.SliceOf(v.ж.ExitNodeFilteredSet) } +func (v DNSConfigView) TempCorpIssue13969() string { return v.ж.TempCorpIssue13969 } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _DNSConfigViewNeedsRegeneration = DNSConfig(struct { @@ -560,6 +583,7 @@ var _DNSConfigViewNeedsRegeneration = DNSConfig(struct { CertDomains []string ExtraRecords []DNSRecord ExitNodeFilteredSet []string + TempCorpIssue13969 string }{}) // View returns a readonly view of RegisterResponse. @@ -607,13 +631,15 @@ func (v *RegisterResponseView) UnmarshalJSON(b []byte) error { return nil } -func (v RegisterResponseView) User() UserView { return v.ж.User.View() } -func (v RegisterResponseView) Login() Login { return v.ж.Login } -func (v RegisterResponseView) NodeKeyExpired() bool { return v.ж.NodeKeyExpired } -func (v RegisterResponseView) MachineAuthorized() bool { return v.ж.MachineAuthorized } -func (v RegisterResponseView) AuthURL() string { return v.ж.AuthURL } -func (v RegisterResponseView) NodeKeySignature() mem.RO { return mem.B(v.ж.NodeKeySignature) } -func (v RegisterResponseView) Error() string { return v.ж.Error } +func (v RegisterResponseView) User() UserView { return v.ж.User.View() } +func (v RegisterResponseView) Login() Login { return v.ж.Login } +func (v RegisterResponseView) NodeKeyExpired() bool { return v.ж.NodeKeyExpired } +func (v RegisterResponseView) MachineAuthorized() bool { return v.ж.MachineAuthorized } +func (v RegisterResponseView) AuthURL() string { return v.ж.AuthURL } +func (v RegisterResponseView) NodeKeySignature() views.ByteSlice[tkatype.MarshaledSignature] { + return views.ByteSliceOf(v.ж.NodeKeySignature) +} +func (v RegisterResponseView) Error() string { return v.ж.Error } // A compilation failure here means this code must 
be regenerated, with the command at the top of this file. var _RegisterResponseViewNeedsRegeneration = RegisterResponse(struct { @@ -626,6 +652,218 @@ var _RegisterResponseViewNeedsRegeneration = RegisterResponse(struct { Error string }{}) +// View returns a readonly view of RegisterResponseAuth. +func (p *RegisterResponseAuth) View() RegisterResponseAuthView { + return RegisterResponseAuthView{ж: p} +} + +// RegisterResponseAuthView provides a read-only view over RegisterResponseAuth. +// +// Its methods should only be called if `Valid()` returns true. +type RegisterResponseAuthView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *RegisterResponseAuth +} + +// Valid reports whether underlying value is non-nil. +func (v RegisterResponseAuthView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. 
+func (v RegisterResponseAuthView) AsStruct() *RegisterResponseAuth { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v RegisterResponseAuthView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *RegisterResponseAuthView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x RegisterResponseAuth + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v RegisterResponseAuthView) Provider() string { return v.ж.Provider } +func (v RegisterResponseAuthView) LoginName() string { return v.ж.LoginName } +func (v RegisterResponseAuthView) Oauth2Token() *Oauth2Token { + if v.ж.Oauth2Token == nil { + return nil + } + x := *v.ж.Oauth2Token + return &x +} + +func (v RegisterResponseAuthView) AuthKey() string { return v.ж.AuthKey } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _RegisterResponseAuthViewNeedsRegeneration = RegisterResponseAuth(struct { + _ structs.Incomparable + Provider string + LoginName string + Oauth2Token *Oauth2Token + AuthKey string +}{}) + +// View returns a readonly view of RegisterRequest. +func (p *RegisterRequest) View() RegisterRequestView { + return RegisterRequestView{ж: p} +} + +// RegisterRequestView provides a read-only view over RegisterRequest. +// +// Its methods should only be called if `Valid()` returns true. +type RegisterRequestView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *RegisterRequest +} + +// Valid reports whether underlying value is non-nil. 
+func (v RegisterRequestView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v RegisterRequestView) AsStruct() *RegisterRequest { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v RegisterRequestView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *RegisterRequestView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x RegisterRequest + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v RegisterRequestView) Version() CapabilityVersion { return v.ж.Version } +func (v RegisterRequestView) NodeKey() key.NodePublic { return v.ж.NodeKey } +func (v RegisterRequestView) OldNodeKey() key.NodePublic { return v.ж.OldNodeKey } +func (v RegisterRequestView) NLKey() key.NLPublic { return v.ж.NLKey } +func (v RegisterRequestView) Auth() RegisterResponseAuthView { return v.ж.Auth.View() } +func (v RegisterRequestView) Expiry() time.Time { return v.ж.Expiry } +func (v RegisterRequestView) Followup() string { return v.ж.Followup } +func (v RegisterRequestView) Hostinfo() HostinfoView { return v.ж.Hostinfo.View() } +func (v RegisterRequestView) Ephemeral() bool { return v.ж.Ephemeral } +func (v RegisterRequestView) NodeKeySignature() views.ByteSlice[tkatype.MarshaledSignature] { + return views.ByteSliceOf(v.ж.NodeKeySignature) +} +func (v RegisterRequestView) SignatureType() SignatureType { return v.ж.SignatureType } +func (v RegisterRequestView) Timestamp() *time.Time { + if v.ж.Timestamp == nil { + return nil + } + x := *v.ж.Timestamp + return &x +} + +func (v RegisterRequestView) DeviceCert() views.ByteSlice[[]byte] { + return views.ByteSliceOf(v.ж.DeviceCert) +} +func (v RegisterRequestView) Signature() views.ByteSlice[[]byte] { + return views.ByteSliceOf(v.ж.Signature) +} + +// A compilation failure 
here means this code must be regenerated, with the command at the top of this file. +var _RegisterRequestViewNeedsRegeneration = RegisterRequest(struct { + _ structs.Incomparable + Version CapabilityVersion + NodeKey key.NodePublic + OldNodeKey key.NodePublic + NLKey key.NLPublic + Auth RegisterResponseAuth + Expiry time.Time + Followup string + Hostinfo *Hostinfo + Ephemeral bool + NodeKeySignature tkatype.MarshaledSignature + SignatureType SignatureType + Timestamp *time.Time + DeviceCert []byte + Signature []byte +}{}) + +// View returns a readonly view of DERPHomeParams. +func (p *DERPHomeParams) View() DERPHomeParamsView { + return DERPHomeParamsView{ж: p} +} + +// DERPHomeParamsView provides a read-only view over DERPHomeParams. +// +// Its methods should only be called if `Valid()` returns true. +type DERPHomeParamsView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *DERPHomeParams +} + +// Valid reports whether underlying value is non-nil. +func (v DERPHomeParamsView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. 
+func (v DERPHomeParamsView) AsStruct() *DERPHomeParams { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v DERPHomeParamsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *DERPHomeParamsView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x DERPHomeParams + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v DERPHomeParamsView) RegionScore() views.Map[int, float64] { + return views.MapOf(v.ж.RegionScore) +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _DERPHomeParamsViewNeedsRegeneration = DERPHomeParams(struct { + RegionScore map[int]float64 +}{}) + // View returns a readonly view of DERPRegion. func (p *DERPRegion) View() DERPRegionView { return DERPRegionView{ж: p} @@ -733,6 +971,8 @@ func (v *DERPMapView) UnmarshalJSON(b []byte) error { return nil } +func (v DERPMapView) HomeParams() DERPHomeParamsView { return v.ж.HomeParams.View() } + func (v DERPMapView) Regions() views.MapFn[int, *DERPRegion, DERPRegionView] { return views.MapFnOf(v.ж.Regions, func(t *DERPRegion) DERPRegionView { return t.View() @@ -742,6 +982,7 @@ func (v DERPMapView) OmitDefaultRegions() bool { return v.ж.OmitDefaultRegions // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _DERPMapViewNeedsRegeneration = DERPMap(struct { + HomeParams *DERPHomeParams Regions map[int]*DERPRegion OmitDefaultRegions bool }{}) @@ -1077,3 +1318,126 @@ func (v ControlDialPlanView) Candidates() views.Slice[ControlIPCandidate] { var _ControlDialPlanViewNeedsRegeneration = ControlDialPlan(struct { Candidates []ControlIPCandidate }{}) + +// View returns a readonly view of Location. 
+func (p *Location) View() LocationView { + return LocationView{ж: p} +} + +// LocationView provides a read-only view over Location. +// +// Its methods should only be called if `Valid()` returns true. +type LocationView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *Location +} + +// Valid reports whether underlying value is non-nil. +func (v LocationView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v LocationView) AsStruct() *Location { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v LocationView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *LocationView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x Location + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v LocationView) Country() string { return v.ж.Country } +func (v LocationView) CountryCode() string { return v.ж.CountryCode } +func (v LocationView) City() string { return v.ж.City } +func (v LocationView) CityCode() string { return v.ж.CityCode } +func (v LocationView) Priority() int { return v.ж.Priority } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _LocationViewNeedsRegeneration = Location(struct { + Country string + CountryCode string + City string + CityCode string + Priority int +}{}) + +// View returns a readonly view of UserProfile. +func (p *UserProfile) View() UserProfileView { + return UserProfileView{ж: p} +} + +// UserProfileView provides a read-only view over UserProfile. 
+// +// Its methods should only be called if `Valid()` returns true. +type UserProfileView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *UserProfile +} + +// Valid reports whether underlying value is non-nil. +func (v UserProfileView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v UserProfileView) AsStruct() *UserProfile { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v UserProfileView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *UserProfileView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x UserProfile + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v UserProfileView) ID() UserID { return v.ж.ID } +func (v UserProfileView) LoginName() string { return v.ж.LoginName } +func (v UserProfileView) DisplayName() string { return v.ж.DisplayName } +func (v UserProfileView) ProfilePicURL() string { return v.ж.ProfilePicURL } +func (v UserProfileView) Roles() emptyStructJSONSlice { return v.ж.Roles } +func (v UserProfileView) Groups() views.Slice[string] { return views.SliceOf(v.ж.Groups) } +func (v UserProfileView) Equal(v2 UserProfileView) bool { return v.ж.Equal(v2.ж) } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. 
+var _UserProfileViewNeedsRegeneration = UserProfile(struct { + ID UserID + LoginName string + DisplayName string + ProfilePicURL string + Roles emptyStructJSONSlice + Groups []string +}{}) diff --git a/vendor/tailscale.com/tempfork/heap/heap.go b/vendor/tailscale.com/tempfork/heap/heap.go new file mode 100644 index 0000000000..3dfab492ad --- /dev/null +++ b/vendor/tailscale.com/tempfork/heap/heap.go @@ -0,0 +1,121 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package heap provides heap operations for any type that implements +// heap.Interface. A heap is a tree with the property that each node is the +// minimum-valued node in its subtree. +// +// The minimum element in the tree is the root, at index 0. +// +// A heap is a common way to implement a priority queue. To build a priority +// queue, implement the Heap interface with the (negative) priority as the +// ordering for the Less method, so Push adds items while Pop removes the +// highest-priority item from the queue. The Examples include such an +// implementation; the file example_pq_test.go has the complete source. +// +// This package is a copy of the Go standard library's +// container/heap, but using generics. +package heap + +import "sort" + +// The Interface type describes the requirements +// for a type using the routines in this package. +// Any type that implements it may be used as a +// min-heap with the following invariants (established after +// Init has been called or if the data is empty or sorted): +// +// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len() +// +// Note that Push and Pop in this interface are for package heap's +// implementation to call. To add and remove things from the heap, +// use heap.Push and heap.Pop. 
+type Interface[V any] interface { + sort.Interface + Push(x V) // add x as element Len() + Pop() V // remove and return element Len() - 1. +} + +// Init establishes the heap invariants required by the other routines in this package. +// Init is idempotent with respect to the heap invariants +// and may be called whenever the heap invariants may have been invalidated. +// The complexity is O(n) where n = h.Len(). +func Init[V any](h Interface[V]) { + // heapify + n := h.Len() + for i := n/2 - 1; i >= 0; i-- { + down(h, i, n) + } +} + +// Push pushes the element x onto the heap. +// The complexity is O(log n) where n = h.Len(). +func Push[V any](h Interface[V], x V) { + h.Push(x) + up(h, h.Len()-1) +} + +// Pop removes and returns the minimum element (according to Less) from the heap. +// The complexity is O(log n) where n = h.Len(). +// Pop is equivalent to Remove(h, 0). +func Pop[V any](h Interface[V]) V { + n := h.Len() - 1 + h.Swap(0, n) + down(h, 0, n) + return h.Pop() +} + +// Remove removes and returns the element at index i from the heap. +// The complexity is O(log n) where n = h.Len(). +func Remove[V any](h Interface[V], i int) V { + n := h.Len() - 1 + if n != i { + h.Swap(i, n) + if !down(h, i, n) { + up(h, i) + } + } + return h.Pop() +} + +// Fix re-establishes the heap ordering after the element at index i has changed its value. +// Changing the value of the element at index i and then calling Fix is equivalent to, +// but less expensive than, calling Remove(h, i) followed by a Push of the new value. +// The complexity is O(log n) where n = h.Len(). 
+func Fix[V any](h Interface[V], i int) { + if !down(h, i, h.Len()) { + up(h, i) + } +} + +func up[V any](h Interface[V], j int) { + for { + i := (j - 1) / 2 // parent + if i == j || !h.Less(j, i) { + break + } + h.Swap(i, j) + j = i + } +} + +func down[V any](h Interface[V], i0, n int) bool { + i := i0 + for { + j1 := 2*i + 1 + if j1 >= n || j1 < 0 { // j1 < 0 after int overflow + break + } + j := j1 // left child + if j2 := j1 + 1; j2 < n && h.Less(j2, j1) { + j = j2 // = 2*i + 2 // right child + } + if !h.Less(j, i) { + break + } + h.Swap(i, j) + i = j + } + return i > i0 +} diff --git a/vendor/tailscale.com/tka/aum.go b/vendor/tailscale.com/tka/aum.go index d21ca199bf..d1f0793983 100644 --- a/vendor/tailscale.com/tka/aum.go +++ b/vendor/tailscale.com/tka/aum.go @@ -9,10 +9,12 @@ import ( "encoding/base32" "errors" "fmt" + "slices" "github.com/fxamacker/cbor/v2" "golang.org/x/crypto/blake2s" "tailscale.com/types/tkatype" + "tailscale.com/util/set" ) // AUMHash represents the BLAKE2s digest of an Authority Update Message (AUM). @@ -37,11 +39,22 @@ func (h *AUMHash) UnmarshalText(text []byte) error { return nil } +// TODO(https://go.dev/issue/53693): Use base32.Encoding.AppendEncode instead. +func base32AppendEncode(enc *base32.Encoding, dst, src []byte) []byte { + n := enc.EncodedLen(len(src)) + dst = slices.Grow(dst, n) + enc.Encode(dst[len(dst):][:n], src) + return dst[:len(dst)+n] +} + +// AppendText implements encoding.TextAppender. +func (h AUMHash) AppendText(b []byte) ([]byte, error) { + return base32AppendEncode(base32StdNoPad, b, h[:]), nil +} + // MarshalText implements encoding.TextMarshaler. func (h AUMHash) MarshalText() ([]byte, error) { - b := make([]byte, base32StdNoPad.EncodedLen(len(h))) - base32StdNoPad.Encode(b, h[:]) - return b, nil + return h.AppendText(nil) } // IsZero returns true if the hash is the empty value. 
@@ -314,7 +327,7 @@ func (a *AUM) Weight(state State) uint { // Despite the wire encoding being []byte, all KeyIDs are // 32 bytes. As such, we use that as the key for the map, // because map keys cannot be slices. - seenKeys := make(map[[32]byte]struct{}, 6) + seenKeys := make(set.Set[[32]byte], 6) for _, sig := range a.Signatures { if len(sig.KeyID) != 32 { panic("unexpected: keyIDs are 32 bytes") @@ -332,12 +345,12 @@ func (a *AUM) Weight(state State) uint { } panic(err) } - if _, seen := seenKeys[keyID]; seen { + if seenKeys.Contains(keyID) { continue } weight += key.Votes - seenKeys[keyID] = struct{}{} + seenKeys.Add(keyID) } return weight diff --git a/vendor/tailscale.com/tka/key.go b/vendor/tailscale.com/tka/key.go index 619fc5f4a6..07736795d8 100644 --- a/vendor/tailscale.com/tka/key.go +++ b/vendor/tailscale.com/tka/key.go @@ -145,6 +145,9 @@ func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key // so we should use the public contained in the state machine. 
switch key.Kind { case Key25519: + if len(key.Public) != ed25519.PublicKeySize { + return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) + } if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { return nil } diff --git a/vendor/tailscale.com/tka/sig.go b/vendor/tailscale.com/tka/sig.go index 8376889c32..212f5431e1 100644 --- a/vendor/tailscale.com/tka/sig.go +++ b/vendor/tailscale.com/tka/sig.go @@ -225,6 +225,9 @@ func (s *NodeKeySignature) verifySignature(nodeKey key.NodePublic, verificationK if !ok { return errors.New("missing rotation key") } + if len(verifyPub) != ed25519.PublicKeySize { + return fmt.Errorf("bad rotation key length: %d", len(verifyPub)) + } if !ed25519.Verify(ed25519.PublicKey(verifyPub[:]), sigHash[:], s.Signature) { return errors.New("invalid signature") } @@ -249,6 +252,9 @@ func (s *NodeKeySignature) verifySignature(nodeKey key.NodePublic, verificationK } switch verificationKey.Kind { case Key25519: + if len(verificationKey.Public) != ed25519.PublicKeySize { + return fmt.Errorf("ed25519 key has wrong length: %d", len(verificationKey.Public)) + } if ed25519consensus.Verify(ed25519.PublicKey(verificationKey.Public), sigHash[:], s.Signature) { return nil } diff --git a/vendor/tailscale.com/tka/tka.go b/vendor/tailscale.com/tka/tka.go index 3903097339..61bee804b3 100644 --- a/vendor/tailscale.com/tka/tka.go +++ b/vendor/tailscale.com/tka/tka.go @@ -9,12 +9,12 @@ import ( "errors" "fmt" "os" - "reflect" "sort" "github.com/fxamacker/cbor/v2" "tailscale.com/types/key" "tailscale.com/types/tkatype" + "tailscale.com/util/set" ) // Strict settings for the CBOR decoder. @@ -29,6 +29,9 @@ var cborDecOpts = cbor.DecOptions{ MaxMapPairs: 1024, } +// Arbitrarily chosen limit on scanning AUM trees. +const maxScanIterations = 2000 + // Authority is a Tailnet Key Authority. This type is the main coupling // point to the rest of the tailscale client. 
// @@ -182,17 +185,6 @@ func advanceByPrimary(state State, candidates []AUM) (next *AUM, out State, err aum := pickNextAUM(state, candidates) - // TODO(tom): Remove this before GA, this is just a correctness check during implementation. - // Post-GA, we want clients to not error if they dont recognize additional fields in State. - if aum.MessageKind == AUMCheckpoint { - dupe := state - dupe.LastAUMHash = nil - // aum.State is non-nil (see aum.StaticValidate). - if !reflect.DeepEqual(dupe, *aum.State) { - return nil, State{}, errors.New("checkpoint includes changes not represented in earlier AUMs") - } - } - if state, err = state.applyVerifiedAUM(aum); err != nil { return nil, State{}, fmt.Errorf("advancing state: %v", err) } @@ -269,13 +261,13 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) var ( curs = topAUM state State - path = make(map[AUMHash]struct{}, 32) // 32 chosen arbitrarily. + path = make(set.Set[AUMHash], 32) // 32 chosen arbitrarily. ) for i := 0; true; i++ { if i > maxIter { return State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter) } - path[curs.Hash()] = struct{}{} + path.Add(curs.Hash()) // Checkpoints encapsulate the state at that point, dope. if curs.MessageKind == AUMCheckpoint { @@ -316,7 +308,7 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) // such, we use a custom advancer here. advancer := func(state State, candidates []AUM) (next *AUM, out State, err error) { for _, c := range candidates { - if _, inPath := path[c.Hash()]; inPath { + if path.Contains(c.Hash()) { if state, err = state.applyVerifiedAUM(c); err != nil { return nil, State{}, fmt.Errorf("advancing state: %v", err) } @@ -334,8 +326,7 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) // as we've already iterated through them above so they must exist, // but we check anyway to be super duper sure. 
if err == nil && *state.LastAUMHash != wantHash { - // TODO(tom): Error instead of panic before GA. - panic("unexpected fastForward outcome") + return State{}, errors.New("unexpected fastForward outcome") } return state, err } @@ -484,7 +475,7 @@ func Open(storage Chonk) (*Authority, error) { return nil, fmt.Errorf("reading last ancestor: %v", err) } - c, err := computeActiveChain(storage, a, 2000) + c, err := computeActiveChain(storage, a, maxScanIterations) if err != nil { return nil, fmt.Errorf("active chain: %v", err) } @@ -617,7 +608,7 @@ func (a *Authority) InformIdempotent(storage Chonk, updates []AUM) (Authority, e state, hasState := stateAt[parent] var err error if !hasState { - if state, err = computeStateAt(storage, 2000, parent); err != nil { + if state, err = computeStateAt(storage, maxScanIterations, parent); err != nil { return Authority{}, fmt.Errorf("update %d computing state: %v", i, err) } stateAt[parent] = state @@ -652,7 +643,7 @@ func (a *Authority) InformIdempotent(storage Chonk, updates []AUM) (Authority, e } oldestAncestor := a.oldestAncestor.Hash() - c, err := computeActiveChain(storage, &oldestAncestor, 2000) + c, err := computeActiveChain(storage, &oldestAncestor, maxScanIterations) if err != nil { return Authority{}, fmt.Errorf("recomputing active chain: %v", err) } @@ -734,3 +725,115 @@ func (a *Authority) Compact(storage CompactableChonk, o CompactionOptions) error a.oldestAncestor = ancestor return nil } + +// findParentForRewrite finds the parent AUM to use when rewriting state to +// retroactively remove trust in the specified keys. +func (a *Authority) findParentForRewrite(storage Chonk, removeKeys []tkatype.KeyID, ourKey tkatype.KeyID) (AUMHash, error) { + cursor := a.Head() + + for { + if cursor == a.oldestAncestor.Hash() { + // We've reached as far back in our history as we can, + // so we have to rewrite from here. 
+ break + } + + aum, err := storage.AUM(cursor) + if err != nil { + return AUMHash{}, fmt.Errorf("reading AUM %v: %w", cursor, err) + } + + // An ideal rewrite parent trusts none of the keys to be removed. + state, err := computeStateAt(storage, maxScanIterations, cursor) + if err != nil { + return AUMHash{}, fmt.Errorf("computing state for %v: %w", cursor, err) + } + keyTrusted := false + for _, key := range removeKeys { + if _, err := state.GetKey(key); err == nil { + keyTrusted = true + } + } + if !keyTrusted { + // Success: the revoked keys are not trusted! + // Lets check that our key was trusted to ensure + // we can sign a fork from here. + if _, err := state.GetKey(ourKey); err == nil { + break + } + } + + parent, hasParent := aum.Parent() + if !hasParent { + // This is the genesis AUM, so we have to rewrite from here. + break + } + cursor = parent + } + + return cursor, nil +} + +// MakeRetroactiveRevocation generates a forking update which revokes the specified keys, in +// such a manner that any malicious use of those keys is erased. +// +// If forkFrom is specified, it is used as the parent AUM to fork from. If the zero value, +// the parent AUM is determined automatically. +// +// The generated AUM must be signed with more signatures than the sum of key votes that +// were compromised, before being consumed by tka.Authority methods. +func (a *Authority) MakeRetroactiveRevocation(storage Chonk, removeKeys []tkatype.KeyID, ourKey tkatype.KeyID, forkFrom AUMHash) (*AUM, error) { + var parent AUMHash + if forkFrom == (AUMHash{}) { + // Make sure at least one of the recovery keys is currently trusted. 
+ foundKey := false + for _, k := range removeKeys { + if _, err := a.state.GetKey(k); err == nil { + foundKey = true + break + } + } + if !foundKey { + return nil, errors.New("no provided key is currently trusted") + } + + p, err := a.findParentForRewrite(storage, removeKeys, ourKey) + if err != nil { + return nil, fmt.Errorf("finding parent: %v", err) + } + parent = p + } else { + parent = forkFrom + } + + // Construct the new state where the revoked keys are no longer trusted. + state := a.state.Clone() + for _, keyToRevoke := range removeKeys { + idx := -1 + for i := range state.Keys { + keyID, err := state.Keys[i].ID() + if err != nil { + return nil, fmt.Errorf("computing keyID: %v", err) + } + if bytes.Equal(keyToRevoke, keyID) { + idx = i + break + } + } + if idx >= 0 { + state.Keys = append(state.Keys[:idx], state.Keys[idx+1:]...) + } + } + if len(state.Keys) == 0 { + return nil, errors.New("cannot revoke all trusted keys") + } + state.LastAUMHash = nil // checkpoints can't specify a LastAUMHash + + forkingAUM := &AUM{ + MessageKind: AUMCheckpoint, + State: &state, + PrevAUMHash: parent[:], + } + + return forkingAUM, forkingAUM.StaticValidate() +} diff --git a/vendor/tailscale.com/tsd/tsd.go b/vendor/tailscale.com/tsd/tsd.go index ffa245bc9c..2a233e51dc 100644 --- a/vendor/tailscale.com/tsd/tsd.go +++ b/vendor/tailscale.com/tsd/tsd.go @@ -21,11 +21,14 @@ import ( "fmt" "reflect" + "tailscale.com/control/controlknobs" "tailscale.com/ipn" "tailscale.com/net/dns" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/net/tstun" + "tailscale.com/proxymap" + "tailscale.com/types/netmap" "tailscale.com/wgengine" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/router" @@ -42,6 +45,17 @@ type System struct { Router SubSystem[router.Router] Tun SubSystem[*tstun.Wrapper] StateStore SubSystem[ipn.StateStore] + Netstack SubSystem[NetstackImpl] // actually a *netstack.Impl + + controlKnobs controlknobs.Knobs + proxyMap proxymap.Mapper +} + +// 
NetstackImpl is the interface that *netstack.Impl implements. +// It's an interface for circular dependency reasons: netstack.Impl +// references LocalBackend, and LocalBackend has a tsd.System. +type NetstackImpl interface { + UpdateNetstackIPs(*netmap.NetworkMap) } // Set is a convenience method to set a subsystem value. @@ -65,6 +79,8 @@ func (s *System) Set(v any) { s.MagicSock.Set(v) case ipn.StateStore: s.StateStore.Set(v) + case NetstackImpl: + s.Netstack.Set(v) default: panic(fmt.Sprintf("unknown type %T", v)) } @@ -85,6 +101,16 @@ func (s *System) IsNetstack() bool { return name == tstun.FakeTUNName } +// ControlKnobs returns the control knobs for this node. +func (s *System) ControlKnobs() *controlknobs.Knobs { + return &s.controlKnobs +} + +// ProxyMapper returns the ephemeral ip:port mapper. +func (s *System) ProxyMapper() *proxymap.Mapper { + return &s.proxyMap +} + // SubSystem represents some subsystem of the Tailscale node daemon. // // A subsystem can be set to a value, and then later retrieved. A subsystem diff --git a/vendor/tailscale.com/tsnet/tsnet.go b/vendor/tailscale.com/tsnet/tsnet.go index 9d7fd7e581..9e6d34f4c5 100644 --- a/vendor/tailscale.com/tsnet/tsnet.go +++ b/vendor/tailscale.com/tsnet/tsnet.go @@ -12,7 +12,6 @@ import ( "crypto/tls" "encoding/hex" "errors" - "flag" "fmt" "io" "log" @@ -23,12 +22,12 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" "sync" "time" - "golang.org/x/exp/slices" "tailscale.com/client/tailscale" "tailscale.com/control/controlclient" "tailscale.com/envknob" @@ -52,16 +51,16 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/nettype" + "tailscale.com/util/clientmetric" "tailscale.com/util/mak" + "tailscale.com/util/testenv" "tailscale.com/wgengine" "tailscale.com/wgengine/netstack" ) -func inTest() bool { return flag.Lookup("test.v") != nil } - // Server is an embedded Tailscale server. 
// -// Its exported fields may be changed until the first call to Listen. +// Its exported fields may be changed until the first method call. type Server struct { // Dir specifies the name of the directory to use for // state. If empty, a directory is selected automatically @@ -108,6 +107,11 @@ type Server struct { // If empty, the Tailscale default is used. ControlURL string + // Port is the UDP port to listen on for WireGuard and peer-to-peer + // traffic. If zero, a port is automatically selected. Leave this + // field at zero unless you know what you are doing. + Port uint16 + getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error) initOnce sync.Once @@ -345,15 +349,6 @@ func (s *Server) Close() error { } }() - if _, isMemStore := s.Store.(*mem.Store); isMemStore && s.Ephemeral && s.lb != nil { - wg.Add(1) - go func() { - defer wg.Done() - // Perform a best-effort logout. - s.lb.LogoutSync(ctx) - }() - } - if s.netstack != nil { s.netstack.Close() s.netstack = nil @@ -413,7 +408,9 @@ func (s *Server) TailscaleIPs() (ip4, ip6 netip.Addr) { if nm == nil { return } - for _, addr := range nm.Addresses { + addrs := nm.GetAddresses() + for i := range addrs.LenIter() { + addr := addrs.At(i) ip := addr.Addr() if ip.Is6() { ip6 = ip @@ -502,10 +499,11 @@ func (s *Server) start() (reterr error) { sys := new(tsd.System) s.dialer = &tsdial.Dialer{Logf: logf} // mutated below (before used) eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ - ListenPort: 0, + ListenPort: s.Port, NetMon: s.netMon, Dialer: s.dialer, SetSubsystem: sys.Set, + ControlKnobs: sys.ControlKnobs(), }) if err != nil { return err @@ -513,10 +511,11 @@ func (s *Server) start() (reterr error) { closePool.add(s.dialer) sys.Set(eng) - ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get()) + ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { 
return fmt.Errorf("netstack.Create: %w", err) } + sys.Set(ns) ns.ProcessLocalIPs = true ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow @@ -555,9 +554,6 @@ func (s *Server) start() (reterr error) { return fmt.Errorf("failed to start netstack: %w", err) } closePool.addFunc(func() { s.lb.Shutdown() }) - lb.SetDecompressor(func() (controlclient.Decompressor, error) { - return smallzstd.NewDecoder(nil) - }) prefs := ipn.NewPrefs() prefs.Hostname = s.hostname prefs.WantRunning = true @@ -600,7 +596,7 @@ func (s *Server) start() (reterr error) { } func (s *Server) startLogger(closePool *closeOnErrorPool) error { - if inTest() { + if testenv.InTest() { return nil } cfgPath := filepath.Join(s.rootPath, "tailscaled.log.conf") @@ -636,7 +632,8 @@ func (s *Server) startLogger(closePool *closeOnErrorPool) error { } return w }, - HTTPC: &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, s.netMon)}, + HTTPC: &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, s.netMon, s.logf)}, + MetricsDelta: clientmetric.EncodeLogTailMetricsDelta, } s.logtail = logtail.NewLogger(c, s.logf) closePool.addFunc(func() { s.logtail.Shutdown(context.Background()) }) @@ -925,7 +922,11 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L if err != nil { return nil, err } - if err := ipn.CheckFunnelAccess(uint16(port), st.Self.Capabilities); err != nil { + // TODO(sonia,tailscale/corp#10577): We may want to use the interactive enable + // flow here instead of CheckFunnelAccess to allow the user to turn on Funnel + // if not already on. Specifically when running from a terminal. + // See cli.serveEnv.verifyFunnelEnabled. 
+ if err := ipn.CheckFunnelAccess(uint16(port), st.Self); err != nil { return nil, err } diff --git a/vendor/tailscale.com/tstime/mono/mono.go b/vendor/tailscale.com/tstime/mono/mono.go index 6b02bb41a7..260e02b0fb 100644 --- a/vendor/tailscale.com/tstime/mono/mono.go +++ b/vendor/tailscale.com/tstime/mono/mono.go @@ -16,7 +16,6 @@ import ( "fmt" "sync/atomic" "time" - _ "unsafe" // for go:linkname ) // Time is the number of nanoseconds elapsed since an unspecified reference start time. @@ -29,7 +28,7 @@ func Now() Time { // The corresponding package time expression never does, if the wall clock is correct. // Preserve this correspondence by increasing the "base" monotonic clock by a fair amount. const baseOffset int64 = 1 << 55 // approximately 10,000 hours in nanoseconds - return Time(now() + baseOffset) + return Time(int64(time.Since(baseWall)) + baseOffset) } // Since returns the time elapsed since t. @@ -72,9 +71,6 @@ func (t *Time) LoadAtomic() Time { return Time(atomic.LoadInt64((*int64)(t))) } -//go:linkname now runtime.nanotime1 -func now() int64 - // baseWall and baseMono are a pair of almost-identical times used to correlate a Time with a wall time. var ( baseWall time.Time @@ -104,7 +100,8 @@ func (t Time) WallTime() time.Time { // MarshalJSON formats t for JSON as if it were a time.Time. // We format Time this way for backwards-compatibility. -// This is best-effort only. Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged. +// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged +// across different invocations of the Go process. This is best-effort only. // Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. // Even in the best of circumstances, it may vary by a few milliseconds. func (t Time) MarshalJSON() ([]byte, error) { @@ -113,7 +110,8 @@ func (t Time) MarshalJSON() ([]byte, error) { } // UnmarshalJSON sets t according to data. -// This is best-effort only. 
Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged. +// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged +// across different invocations of the Go process. This is best-effort only. func (t *Time) UnmarshalJSON(data []byte) error { var tt time.Time err := tt.UnmarshalJSON(data) @@ -124,6 +122,6 @@ func (t *Time) UnmarshalJSON(data []byte) error { *t = 0 return nil } - *t = Now().Add(-time.Since(tt)) + *t = baseMono.Add(tt.Sub(baseWall)) return nil } diff --git a/vendor/tailscale.com/tstime/tstime.go b/vendor/tailscale.com/tstime/tstime.go index cc76755cc5..69073639af 100644 --- a/vendor/tailscale.com/tstime/tstime.go +++ b/vendor/tailscale.com/tstime/tstime.go @@ -59,3 +59,87 @@ func Sleep(ctx context.Context, d time.Duration) bool { return true } } + +// Clock offers a subset of the functionality from the std/time package. +// Normally, applications will use the StdClock implementation that calls the +// appropriate std/time exported funcs. The advantage of using Clock is that +// tests can substitute a different implementation, allowing the test to control +// time precisely, something required for certain types of tests to be possible +// at all, speeds up execution by not needing to sleep, and can dramatically +// reduce the risk of flakes due to tests executing too slowly or quickly. +type Clock interface { + // Now returns the current time, as in time.Now. + Now() time.Time + // NewTimer returns a timer whose notion of the current time is controlled + // by this Clock. It follows the semantics of time.NewTimer as closely as + // possible but is adapted to return an interface, so the channel needs to + // be returned as well. + NewTimer(d time.Duration) (TimerController, <-chan time.Time) + // NewTicker returns a ticker whose notion of the current time is controlled + // by this Clock. 
It follows the semantics of time.NewTicker as closely as + // possible but is adapted to return an interface, so the channel needs to + // be returned as well. + NewTicker(d time.Duration) (TickerController, <-chan time.Time) + // AfterFunc returns a ticker whose notion of the current time is controlled + // by this Clock. When the ticker expires, it will call the provided func. + // It follows the semantics of time.AfterFunc. + AfterFunc(d time.Duration, f func()) TimerController + // Since returns the time elapsed since t. + // It follows the semantics of time.Since. + Since(t time.Time) time.Duration +} + +// TickerController offers the receivers of a time.Ticker to ensure +// compatibility with standard timers, but allows for the option of substituting +// a standard timer with something else for testing purposes. +type TickerController interface { + // Reset follows the same semantics as with time.Ticker.Reset. + Reset(d time.Duration) + // Stop follows the same semantics as with time.Ticker.Stop. + Stop() +} + +// TimerController offers the receivers of a time.Timer to ensure +// compatibility with standard timers, but allows for the option of substituting +// a standard timer with something else for testing purposes. +type TimerController interface { + // Reset follows the same semantics as with time.Timer.Reset. + Reset(d time.Duration) bool + // Stop follows the same semantics as with time.Timer.Stop. + Stop() bool +} + +// StdClock is a simple implementation of Clock using the relevant funcs in the +// std/time package. +type StdClock struct{} + +// Now calls time.Now. +func (StdClock) Now() time.Time { + return time.Now() +} + +// NewTimer calls time.NewTimer. As an interface does not allow for struct +// members and other packages cannot add receivers to another package, the +// channel is also returned because it would be otherwise inaccessible. 
+func (StdClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { + t := time.NewTimer(d) + return t, t.C +} + +// NewTicker calls time.NewTicker. As an interface does not allow for struct +// members and other packages cannot add receivers to another package, the +// channel is also returned because it would be otherwise inaccessible. +func (StdClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { + t := time.NewTicker(d) + return t, t.C +} + +// AfterFunc calls time.AfterFunc. +func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { + return time.AfterFunc(d, f) +} + +// Since calls time.Since. +func (StdClock) Since(t time.Time) time.Duration { + return time.Since(t) +} diff --git a/vendor/tailscale.com/types/dnstype/dnstype.go b/vendor/tailscale.com/types/dnstype/dnstype.go index a5137fa792..ae3d1defc6 100644 --- a/vendor/tailscale.com/types/dnstype/dnstype.go +++ b/vendor/tailscale.com/types/dnstype/dnstype.go @@ -8,6 +8,7 @@ package dnstype import ( "net/netip" + "slices" ) // Resolver is the configuration for one DNS resolver. @@ -51,3 +52,15 @@ func (r *Resolver) IPPort() (ipp netip.AddrPort, ok bool) { } return } + +// Equal reports whether r and other are equal. 
+func (r *Resolver) Equal(other *Resolver) bool { + if r == nil || other == nil { + return r == other + } + if r == other { + return true + } + + return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) +} diff --git a/vendor/tailscale.com/types/dnstype/dnstype_view.go b/vendor/tailscale.com/types/dnstype/dnstype_view.go index b8f1e0312f..c0e2b28ffb 100644 --- a/vendor/tailscale.com/types/dnstype/dnstype_view.go +++ b/vendor/tailscale.com/types/dnstype/dnstype_view.go @@ -64,6 +64,7 @@ func (v ResolverView) Addr() string { return v.ж.Addr } func (v ResolverView) BootstrapResolution() views.Slice[netip.Addr] { return views.SliceOf(v.ж.BootstrapResolution) } +func (v ResolverView) Equal(v2 ResolverView) bool { return v.ж.Equal(v2.ж) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _ResolverViewNeedsRegeneration = Resolver(struct { diff --git a/vendor/tailscale.com/types/key/chal.go b/vendor/tailscale.com/types/key/chal.go index 8b46f4e526..742ac5479e 100644 --- a/vendor/tailscale.com/types/key/chal.go +++ b/vendor/tailscale.com/types/key/chal.go @@ -72,9 +72,14 @@ func (k ChallengePublic) String() string { return string(bs) } +// AppendText implements encoding.TextAppender. +func (k ChallengePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, chalPublicHexPrefix, k.k[:]), nil +} + // MarshalText implements encoding.TextMarshaler. func (k ChallengePublic) MarshalText() ([]byte, error) { - return toHex(k.k[:], chalPublicHexPrefix), nil + return k.AppendText(nil) } // UnmarshalText implements encoding.TextUnmarshaler. 
diff --git a/vendor/tailscale.com/types/key/disco.go b/vendor/tailscale.com/types/key/disco.go index 8f8ee5c886..14005b5067 100644 --- a/vendor/tailscale.com/types/key/disco.go +++ b/vendor/tailscale.com/types/key/disco.go @@ -127,9 +127,14 @@ func (k DiscoPublic) String() string { return string(bs) } +// AppendText implements encoding.TextAppender. +func (k DiscoPublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, discoPublicHexPrefix, k.k[:]), nil +} + // MarshalText implements encoding.TextMarshaler. func (k DiscoPublic) MarshalText() ([]byte, error) { - return toHex(k.k[:], discoPublicHexPrefix), nil + return k.AppendText(nil) } // MarshalText implements encoding.TextUnmarshaler. diff --git a/vendor/tailscale.com/types/key/machine.go b/vendor/tailscale.com/types/key/machine.go index 0a81e63a7b..a05f3cc1f5 100644 --- a/vendor/tailscale.com/types/key/machine.go +++ b/vendor/tailscale.com/types/key/machine.go @@ -67,9 +67,14 @@ func (k MachinePrivate) Public() MachinePublic { return ret } +// AppendText implements encoding.TextAppender. +func (k MachinePrivate) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, machinePrivateHexPrefix, k.k[:]), nil +} + // MarshalText implements encoding.TextMarshaler. func (k MachinePrivate) MarshalText() ([]byte, error) { - return toHex(k.k[:], machinePrivateHexPrefix), nil + return k.AppendText(nil) } // MarshalText implements encoding.TextUnmarshaler. @@ -243,9 +248,14 @@ func (k MachinePublic) String() string { return string(bs) } +// AppendText implements encoding.TextAppender. +func (k MachinePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, machinePublicHexPrefix, k.k[:]), nil +} + // MarshalText implements encoding.TextMarshaler. func (k MachinePublic) MarshalText() ([]byte, error) { - return toHex(k.k[:], machinePublicHexPrefix), nil + return k.AppendText(nil) } // MarshalText implements encoding.TextUnmarshaler. 
diff --git a/vendor/tailscale.com/types/key/nl.go b/vendor/tailscale.com/types/key/nl.go index 1b70e21862..e0b4e5ca61 100644 --- a/vendor/tailscale.com/types/key/nl.go +++ b/vendor/tailscale.com/types/key/nl.go @@ -61,9 +61,14 @@ func (k *NLPrivate) UnmarshalText(b []byte) error { return parseHex(k.k[:], mem.B(b), mem.S(nlPrivateHexPrefix)) } +// AppendText implements encoding.TextAppender. +func (k NLPrivate) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, nlPrivateHexPrefix, k.k[:]), nil +} + // MarshalText implements encoding.TextMarshaler. func (k NLPrivate) MarshalText() ([]byte, error) { - return toHex(k.k[:], nlPrivateHexPrefix), nil + return k.AppendText(nil) } // Equal reports whether k and other are the same key. @@ -132,10 +137,15 @@ func (k *NLPublic) UnmarshalText(b []byte) error { return parseHex(k.k[:], mem.B(b), mem.S(nlPublicHexPrefix)) } +// AppendText implements encoding.TextAppender. +func (k NLPublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, nlPublicHexPrefix, k.k[:]), nil +} + // MarshalText implements encoding.TextMarshaler, emitting a // representation of the form nlpub:. func (k NLPublic) MarshalText() ([]byte, error) { - return toHex(k.k[:], nlPublicHexPrefix), nil + return k.AppendText(nil) } // CLIString returns a marshalled representation suitable for use @@ -143,7 +153,7 @@ func (k NLPublic) MarshalText() ([]byte, error) { // the nlpub: form emitted by MarshalText. Both forms can // be decoded by UnmarshalText. 
func (k NLPublic) CLIString() string { - return string(toHex(k.k[:], nlPublicHexPrefixCLI)) + return string(appendHexKey(nil, nlPublicHexPrefixCLI, k.k[:])) } // Verifier returns a ed25519.PublicKey that can be used to diff --git a/vendor/tailscale.com/types/key/node.go b/vendor/tailscale.com/types/key/node.go index a840572313..4cb7287986 100644 --- a/vendor/tailscale.com/types/key/node.go +++ b/vendor/tailscale.com/types/key/node.go @@ -103,9 +103,14 @@ func (k NodePrivate) Public() NodePublic { return ret } +// AppendText implements encoding.TextAppender. +func (k NodePrivate) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, nodePrivateHexPrefix, k.k[:]), nil +} + // MarshalText implements encoding.TextMarshaler. func (k NodePrivate) MarshalText() ([]byte, error) { - return toHex(k.k[:], nodePrivateHexPrefix), nil + return k.AppendText(nil) } // MarshalText implements encoding.TextUnmarshaler. @@ -163,6 +168,12 @@ func (p NodePublic) Shard() uint8 { return s ^ uint8(p.k[2]+p.k[12]) } +// Compare returns -1, 0, or 1, depending on whether p orders before p2, +// using bytes.Compare on the bytes of the public key. +func (p NodePublic) Compare(p2 NodePublic) int { + return bytes.Compare(p.k[:], p2.k[:]) +} + // ParseNodePublicUntyped parses an untyped 64-character hex value // as a NodePublic. // @@ -308,9 +319,14 @@ func (k NodePublic) String() string { return string(bs) } +// AppendText implements encoding.TextAppender. +func (k NodePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, nodePublicHexPrefix, k.k[:]), nil +} + // MarshalText implements encoding.TextMarshaler. func (k NodePublic) MarshalText() ([]byte, error) { - return toHex(k.k[:], nodePublicHexPrefix), nil + return k.AppendText(nil) } // MarshalText implements encoding.TextUnmarshaler. 
diff --git a/vendor/tailscale.com/types/key/util.go b/vendor/tailscale.com/types/key/util.go index cf079be950..f20cb42792 100644 --- a/vendor/tailscale.com/types/key/util.go +++ b/vendor/tailscale.com/types/key/util.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "slices" "go4.org/mem" ) @@ -49,11 +50,19 @@ func clamp25519Private(b []byte) { b[31] = (b[31] & 127) | 64 } -func toHex(k []byte, prefix string) []byte { - ret := make([]byte, len(prefix)+len(k)*2) - copy(ret, prefix) - hex.Encode(ret[len(prefix):], k) - return ret +func appendHexKey(dst []byte, prefix string, key []byte) []byte { + dst = slices.Grow(dst, len(prefix)+hex.EncodedLen(len(key))) + dst = append(dst, prefix...) + dst = hexAppendEncode(dst, key) + return dst +} + +// TODO(https://go.dev/issue/53693): Use hex.AppendEncode instead. +func hexAppendEncode(dst, src []byte) []byte { + n := hex.EncodedLen(len(src)) + dst = slices.Grow(dst, n) + hex.Encode(dst[len(dst):][:n], src) + return dst[:len(dst)+n] } // parseHex decodes a key string of the form "" diff --git a/vendor/tailscale.com/types/logger/logger.go b/vendor/tailscale.com/types/logger/logger.go index 1df273b3e1..232679ba78 100644 --- a/vendor/tailscale.com/types/logger/logger.go +++ b/vendor/tailscale.com/types/logger/logger.go @@ -353,3 +353,39 @@ func LogfCloser(logf Logf) (newLogf Logf, close func()) { } return newLogf, close } + +// AsJSON returns a formatter that formats v as JSON. The value is suitable to +// passing to a regular %v printf argument. (%s is not required) +// +// If json.Marshal returns an error, the output is "%%!JSON-ERROR:" followed by +// the error string. +func AsJSON(v any) fmt.Formatter { + return asJSONResult{v} +} + +type asJSONResult struct{ v any } + +func (a asJSONResult) Format(s fmt.State, verb rune) { + v, err := json.Marshal(a.v) + if err != nil { + fmt.Fprintf(s, "%%!JSON-ERROR:%v", err) + return + } + s.Write(v) +} + +// TBLogger is the testing.TB subset needed by TestLogger. 
+type TBLogger interface { + Helper() + Logf(format string, args ...any) +} + +// TestLogger returns a logger that logs to tb.Logf +// with a prefix to make it easier to distinguish spam +// from explicit test failures. +func TestLogger(tb TBLogger) Logf { + return func(format string, args ...any) { + tb.Helper() + tb.Logf(" ... "+format, args...) + } +} diff --git a/vendor/tailscale.com/types/logger/rusage_stub.go b/vendor/tailscale.com/types/logger/rusage_stub.go index 5aba2066b5..f646f1e1ee 100644 --- a/vendor/tailscale.com/types/logger/rusage_stub.go +++ b/vendor/tailscale.com/types/logger/rusage_stub.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build windows || js || wasip1 +//go:build windows || wasm || plan9 || tamago package logger diff --git a/vendor/tailscale.com/types/logger/rusage_syscall.go b/vendor/tailscale.com/types/logger/rusage_syscall.go index 6378521b6c..2871b66c6b 100644 --- a/vendor/tailscale.com/types/logger/rusage_syscall.go +++ b/vendor/tailscale.com/types/logger/rusage_syscall.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !windows && !js && !wasip1 +//go:build !windows && !wasm && !plan9 && !tamago package logger diff --git a/vendor/tailscale.com/types/logid/id.go b/vendor/tailscale.com/types/logid/id.go index f3d705f182..ac4369a4e8 100644 --- a/vendor/tailscale.com/types/logid/id.go +++ b/vendor/tailscale.com/types/logid/id.go @@ -11,9 +11,8 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "slices" "unicode/utf8" - - "golang.org/x/exp/slices" ) // PrivateID represents a log steam for writing. 
@@ -39,8 +38,12 @@ func ParsePrivateID(in string) (out PrivateID, err error) { return out, err } +func (id PrivateID) AppendText(b []byte) ([]byte, error) { + return hexAppendEncode(b, id[:]), nil +} + func (id PrivateID) MarshalText() ([]byte, error) { - return formatID(id), nil + return id.AppendText(nil) } func (id *PrivateID) UnmarshalText(in []byte) error { @@ -48,7 +51,7 @@ func (id *PrivateID) UnmarshalText(in []byte) error { } func (id PrivateID) String() string { - return string(formatID(id)) + return string(hexAppendEncode(nil, id[:])) } func (id PrivateID) IsZero() bool { @@ -71,8 +74,12 @@ func ParsePublicID(in string) (out PublicID, err error) { return out, err } +func (id PublicID) AppendText(b []byte) ([]byte, error) { + return hexAppendEncode(b, id[:]), nil +} + func (id PublicID) MarshalText() ([]byte, error) { - return formatID(id), nil + return id.AppendText(nil) } func (id *PublicID) UnmarshalText(in []byte) error { @@ -80,11 +87,15 @@ func (id *PublicID) UnmarshalText(in []byte) error { } func (id PublicID) String() string { - return string(formatID(id)) + return string(hexAppendEncode(nil, id[:])) } func (id1 PublicID) Less(id2 PublicID) bool { - return slices.Compare(id1[:], id2[:]) < 0 + return id1.Compare(id2) < 0 +} + +func (id1 PublicID) Compare(id2 PublicID) int { + return slices.Compare(id1[:], id2[:]) } func (id PublicID) IsZero() bool { @@ -95,10 +106,12 @@ func (id PublicID) Prefix64() uint64 { return binary.BigEndian.Uint64(id[:8]) } -func formatID(in [32]byte) []byte { - var hexArr [2 * len(in)]byte - hex.Encode(hexArr[:], in[:]) - return hexArr[:] +// TODO(https://go.dev/issue/53693): Use hex.AppendEncode instead. 
+func hexAppendEncode(dst, src []byte) []byte { + n := hex.EncodedLen(len(src)) + dst = slices.Grow(dst, n) + hex.Encode(dst[len(dst):][:n], src) + return dst[:len(dst)+n] } func parseID[Bytes []byte | string](funcName string, out *[32]byte, in Bytes) (err error) { diff --git a/vendor/tailscale.com/types/netmap/netmap.go b/vendor/tailscale.com/types/netmap/netmap.go index b3a9df8ba5..9e2c154066 100644 --- a/vendor/tailscale.com/types/netmap/netmap.go +++ b/vendor/tailscale.com/types/netmap/netmap.go @@ -8,7 +8,7 @@ import ( "encoding/json" "fmt" "net/netip" - "reflect" + "sort" "strings" "time" @@ -16,6 +16,7 @@ import ( "tailscale.com/tka" "tailscale.com/types/key" "tailscale.com/types/views" + "tailscale.com/util/cmpx" "tailscale.com/wgengine/filter" ) @@ -24,21 +25,19 @@ import ( // The fields should all be considered read-only. They might // alias parts of previous NetworkMap values. type NetworkMap struct { - // Core networking - - SelfNode *tailcfg.Node + SelfNode tailcfg.NodeView NodeKey key.NodePublic PrivateKey key.NodePrivate Expiry time.Time // Name is the DNS name assigned to this node. - Name string - Addresses []netip.Prefix // same as tailcfg.Node.Addresses (IP addresses of this Node directly) - MachineStatus tailcfg.MachineStatus - MachineKey key.MachinePublic - Peers []*tailcfg.Node // sorted by Node.ID - DNS tailcfg.DNSConfig - // TODO(maisem) : replace with View. - Hostinfo tailcfg.Hostinfo + // It is the MapResponse.Node.Name value and ends with a period. + Name string + + MachineKey key.MachinePublic + + Peers []tailcfg.NodeView // sorted by Node.ID + DNS tailcfg.DNSConfig + PacketFilter []filter.Match PacketFilterRules views.Slice[tailcfg.FilterRule] SSHPolicy *tailcfg.SSHPolicy // or nil, if not enabled/allowed @@ -53,9 +52,6 @@ type NetworkMap struct { // between updates and should not be modified. DERPMap *tailcfg.DERPMap - // Debug knobs from control server for debug or feature gating. 
- Debug *tailcfg.Debug - // ControlHealth are the list of health check problems for this // node from the perspective of the control plane. // If empty, there are no known problems from the control plane's @@ -70,10 +66,6 @@ type NetworkMap struct { // hash of the latest update message to tick through TKA). TKAHead tka.AUMHash - // ACLs - - User tailcfg.UserID - // Domain is the current Tailnet name. Domain string @@ -85,55 +77,121 @@ type NetworkMap struct { UserProfiles map[tailcfg.UserID]tailcfg.UserProfile } +// User returns nm.SelfNode.User if nm.SelfNode is non-nil, otherwise it returns +// 0. +func (nm *NetworkMap) User() tailcfg.UserID { + if nm.SelfNode.Valid() { + return nm.SelfNode.User() + } + return 0 +} + +// GetAddresses returns the self node's addresses, or the zero value +// if SelfNode is invalid. +func (nm *NetworkMap) GetAddresses() views.Slice[netip.Prefix] { + var zero views.Slice[netip.Prefix] + if !nm.SelfNode.Valid() { + return zero + } + return nm.SelfNode.Addresses() +} + // AnyPeersAdvertiseRoutes reports whether any peer is advertising non-exit node routes. func (nm *NetworkMap) AnyPeersAdvertiseRoutes() bool { for _, p := range nm.Peers { - if len(p.PrimaryRoutes) > 0 { + if p.PrimaryRoutes().Len() > 0 { return true } } return false } +// GetMachineStatus returns the MachineStatus of the local node. +func (nm *NetworkMap) GetMachineStatus() tailcfg.MachineStatus { + if !nm.SelfNode.Valid() { + return tailcfg.MachineUnknown + } + if nm.SelfNode.MachineAuthorized() { + return tailcfg.MachineAuthorized + } + return tailcfg.MachineUnauthorized +} + // PeerByTailscaleIP returns a peer's Node based on its Tailscale IP. // // If nm is nil or no peer is found, ok is false. 
-func (nm *NetworkMap) PeerByTailscaleIP(ip netip.Addr) (peer *tailcfg.Node, ok bool) { +func (nm *NetworkMap) PeerByTailscaleIP(ip netip.Addr) (peer tailcfg.NodeView, ok bool) { // TODO(bradfitz): if nm == nil { - return nil, false + return tailcfg.NodeView{}, false } for _, n := range nm.Peers { - for _, a := range n.Addresses { + ad := n.Addresses() + for i := 0; i < ad.Len(); i++ { + a := ad.At(i) if a.Addr() == ip { return n, true } } } - return nil, false + return tailcfg.NodeView{}, false } -// MagicDNSSuffix returns the domain's MagicDNS suffix (even if -// MagicDNS isn't necessarily in use). +// PeerIndexByNodeID returns the index of the peer with the given nodeID +// in nm.Peers, or -1 if nm is nil or not found. +// +// It assumes nm.Peers is sorted by Node.ID. +func (nm *NetworkMap) PeerIndexByNodeID(nodeID tailcfg.NodeID) int { + if nm == nil { + return -1 + } + idx, ok := sort.Find(len(nm.Peers), func(i int) int { + return cmpx.Compare(nodeID, nm.Peers[i].ID()) + }) + if !ok { + return -1 + } + return idx +} + +// MagicDNSSuffix returns the domain's MagicDNS suffix (even if MagicDNS isn't +// necessarily in use) of the provided Node.Name value. // // It will neither start nor end with a period. -func (nm *NetworkMap) MagicDNSSuffix() string { - name := strings.Trim(nm.Name, ".") +func MagicDNSSuffixOfNodeName(nodeName string) string { + name := strings.Trim(nodeName, ".") if _, rest, ok := strings.Cut(name, "."); ok { return rest } return name } +// MagicDNSSuffix returns the domain's MagicDNS suffix (even if +// MagicDNS isn't necessarily in use). +// +// It will neither start nor end with a period. +func (nm *NetworkMap) MagicDNSSuffix() string { + if nm == nil { + return "" + } + return MagicDNSSuffixOfNodeName(nm.Name) +} + // SelfCapabilities returns SelfNode.Capabilities if nm and nm.SelfNode are // non-nil. This is a method so we can use it in envknob/logknob without a // circular dependency. 
-func (nm *NetworkMap) SelfCapabilities() []string { - if nm == nil || nm.SelfNode == nil { - return nil +func (nm *NetworkMap) SelfCapabilities() views.Slice[tailcfg.NodeCapability] { + var zero views.Slice[tailcfg.NodeCapability] + if nm == nil || !nm.SelfNode.Valid() { + return zero } + out := nm.SelfNode.Capabilities().AsSlice() + nm.SelfNode.CapMap().Range(func(k tailcfg.NodeCapability, _ views.Slice[tailcfg.RawMessage]) (cont bool) { + out = append(out, k) + return true + }) - return nm.SelfNode.Capabilities + return views.SliceOf(out) } func (nm *NetworkMap) String() string { @@ -157,13 +215,13 @@ func (nm *NetworkMap) VeryConcise() string { } // PeerWithStableID finds and returns the peer associated to the inputted StableNodeID. -func (nm *NetworkMap) PeerWithStableID(pid tailcfg.StableNodeID) (_ *tailcfg.Node, ok bool) { +func (nm *NetworkMap) PeerWithStableID(pid tailcfg.StableNodeID) (_ tailcfg.NodeView, ok bool) { for _, p := range nm.Peers { - if p.StableID == pid { + if p.StableID() == pid { return p, true } } - return nil, false + return tailcfg.NodeView{}, false } // printConciseHeader prints a concise header line representing nm to buf. @@ -172,54 +230,44 @@ func (nm *NetworkMap) PeerWithStableID(pid tailcfg.StableNodeID) (_ *tailcfg.Nod // in equalConciseHeader in sync. func (nm *NetworkMap) printConciseHeader(buf *strings.Builder) { fmt.Fprintf(buf, "netmap: self: %v auth=%v", - nm.NodeKey.ShortString(), nm.MachineStatus) - login := nm.UserProfiles[nm.User].LoginName + nm.NodeKey.ShortString(), nm.GetMachineStatus()) + login := nm.UserProfiles[nm.User()].LoginName if login == "" { - if nm.User.IsZero() { + if nm.User().IsZero() { login = "?" 
} else { - login = fmt.Sprint(nm.User) + login = fmt.Sprint(nm.User()) } } fmt.Fprintf(buf, " u=%s", login) - if nm.Debug != nil { - j, _ := json.Marshal(nm.Debug) - fmt.Fprintf(buf, " debug=%s", j) - } - fmt.Fprintf(buf, " %v", nm.Addresses) + fmt.Fprintf(buf, " %v", nm.GetAddresses().AsSlice()) buf.WriteByte('\n') } // equalConciseHeader reports whether a and b are equal for the fields // used by printConciseHeader. func (a *NetworkMap) equalConciseHeader(b *NetworkMap) bool { - if a.NodeKey != b.NodeKey || - a.MachineStatus != b.MachineStatus || - a.User != b.User || - len(a.Addresses) != len(b.Addresses) { - return false - } - for i, a := range a.Addresses { - if b.Addresses[i] != a { - return false - } - } - return (a.Debug == nil && b.Debug == nil) || reflect.DeepEqual(a.Debug, b.Debug) + return a.NodeKey == b.NodeKey && + a.GetMachineStatus() == b.GetMachineStatus() && + a.User() == b.User() && + views.SliceEqual(a.GetAddresses(), b.GetAddresses()) } // printPeerConcise appends to buf a line representing the peer p. // // If this function is changed to access different fields of p, keep // in nodeConciseEqual in sync. 
-func printPeerConcise(buf *strings.Builder, p *tailcfg.Node) { - aip := make([]string, len(p.AllowedIPs)) - for i, a := range p.AllowedIPs { +func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { + aip := make([]string, p.AllowedIPs().Len()) + for i := range aip { + a := p.AllowedIPs().At(i) s := strings.TrimSuffix(fmt.Sprint(a), "/32") aip[i] = s } - ep := make([]string, len(p.Endpoints)) - for i, e := range p.Endpoints { + ep := make([]string, p.Endpoints().Len()) + for i := range ep { + e := p.Endpoints().At(i) // Align vertically on the ':' between IP and port colon := strings.IndexByte(e, ':') spaces := 0 @@ -230,21 +278,21 @@ func printPeerConcise(buf *strings.Builder, p *tailcfg.Node) { ep[i] = fmt.Sprintf("%21v", e+strings.Repeat(" ", spaces)) } - derp := p.DERP + derp := p.DERP() const derpPrefix = "127.3.3.40:" if strings.HasPrefix(derp, derpPrefix) { derp = "D" + derp[len(derpPrefix):] } var discoShort string - if !p.DiscoKey.IsZero() { - discoShort = p.DiscoKey.ShortString() + " " + if !p.DiscoKey().IsZero() { + discoShort = p.DiscoKey().ShortString() + " " } // Most of the time, aip is just one element, so format the // table to look good in that case. This will also make multi- // subnet nodes stand out visually. fmt.Fprintf(buf, " %v %s%-2v %-15v : %v\n", - p.Key.ShortString(), + p.Key().ShortString(), discoShort, derp, strings.Join(aip, " "), @@ -252,12 +300,12 @@ func printPeerConcise(buf *strings.Builder, p *tailcfg.Node) { } // nodeConciseEqual reports whether a and b are equal for the fields accessed by printPeerConcise. 
-func nodeConciseEqual(a, b *tailcfg.Node) bool { - return a.Key == b.Key && - a.DERP == b.DERP && - a.DiscoKey == b.DiscoKey && - eqCIDRsIgnoreNil(a.AllowedIPs, b.AllowedIPs) && - eqStringsIgnoreNil(a.Endpoints, b.Endpoints) +func nodeConciseEqual(a, b tailcfg.NodeView) bool { + return a.Key() == b.Key() && + a.DERP() == b.DERP() && + a.DiscoKey() == b.DiscoKey() && + views.SliceEqual(a.AllowedIPs(), b.AllowedIPs()) && + views.SliceEqual(a.Endpoints(), b.Endpoints()) } func (b *NetworkMap) ConciseDiffFrom(a *NetworkMap) string { @@ -276,7 +324,7 @@ func (b *NetworkMap) ConciseDiffFrom(a *NetworkMap) string { for len(aps) > 0 && len(bps) > 0 { pa, pb := aps[0], bps[0] switch { - case pa.ID == pb.ID: + case pa.ID() == pb.ID(): if !nodeConciseEqual(pa, pb) { diff.WriteByte('-') printPeerConcise(&diff, pa) @@ -284,12 +332,12 @@ func (b *NetworkMap) ConciseDiffFrom(a *NetworkMap) string { printPeerConcise(&diff, pb) } aps, bps = aps[1:], bps[1:] - case pa.ID > pb.ID: + case pa.ID() > pb.ID(): // New peer in b. diff.WriteByte('+') printPeerConcise(&diff, pb) bps = bps[1:] - case pb.ID > pa.ID: + case pb.ID() > pa.ID(): // Deleted peer in b. diff.WriteByte('-') printPeerConcise(&diff, pa) @@ -323,31 +371,3 @@ const ( AllowSingleHosts WGConfigFlags = 1 << iota AllowSubnetRoutes ) - -// eqStringsIgnoreNil reports whether a and b have the same length and -// contents, but ignore whether a or b are nil. -func eqStringsIgnoreNil(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} - -// eqCIDRsIgnoreNil reports whether a and b have the same length and -// contents, but ignore whether a or b are nil. 
-func eqCIDRsIgnoreNil(a, b []netip.Prefix) bool {
-	if len(a) != len(b) {
-		return false
-	}
-	for i, v := range a {
-		if v != b[i] {
-			return false
-		}
-	}
-	return true
-}
diff --git a/vendor/tailscale.com/types/netmap/nodemut.go b/vendor/tailscale.com/types/netmap/nodemut.go
new file mode 100644
index 0000000000..919fe0492b
--- /dev/null
+++ b/vendor/tailscale.com/types/netmap/nodemut.go
@@ -0,0 +1,191 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netmap
+
+import (
+	"fmt"
+	"net/netip"
+	"reflect"
+	"slices"
+	"sync"
+	"time"
+
+	"tailscale.com/tailcfg"
+	"tailscale.com/types/ptr"
+	"tailscale.com/util/cmpx"
+)
+
+// NodeMutation is the common interface for types that describe
+// the change of a node's state.
+type NodeMutation interface {
+	NodeIDBeingMutated() tailcfg.NodeID
+	Apply(*tailcfg.Node)
+}
+
+type mutatingNodeID tailcfg.NodeID
+
+func (m mutatingNodeID) NodeIDBeingMutated() tailcfg.NodeID { return tailcfg.NodeID(m) }
+
+// NodeMutationDERPHome is a NodeMutation that says a node
+// has changed its DERP home region.
+type NodeMutationDERPHome struct {
+	mutatingNodeID
+	DERPRegion int
+}
+
+func (m NodeMutationDERPHome) Apply(n *tailcfg.Node) {
+	n.DERP = fmt.Sprintf("127.3.3.40:%v", m.DERPRegion)
+}
+
+// NodeMutationEndpoints is a NodeMutation that says a node's endpoints have changed.
+type NodeMutationEndpoints struct {
+	mutatingNodeID
+	Endpoints []netip.AddrPort
+}
+
+func (m NodeMutationEndpoints) Apply(n *tailcfg.Node) {
+	eps := make([]string, len(m.Endpoints))
+	for i, ep := range m.Endpoints {
+		eps[i] = ep.String()
+	}
+	n.Endpoints = eps
+}
+
+// NodeMutationOnline is a NodeMutation that says a node is now online or
+// offline.
+type NodeMutationOnline struct { + mutatingNodeID + Online bool +} + +func (m NodeMutationOnline) Apply(n *tailcfg.Node) { + n.Online = ptr.To(m.Online) +} + +// NodeMutationLastSeen is a NodeMutation that says a node's LastSeen +// value should be set to the current time. +type NodeMutationLastSeen struct { + mutatingNodeID + LastSeen time.Time +} + +func (m NodeMutationLastSeen) Apply(n *tailcfg.Node) { + n.LastSeen = ptr.To(m.LastSeen) +} + +var peerChangeFields = sync.OnceValue(func() []reflect.StructField { + var fields []reflect.StructField + rt := reflect.TypeOf((*tailcfg.PeerChange)(nil)).Elem() + for i := 0; i < rt.NumField(); i++ { + fields = append(fields, rt.Field(i)) + } + return fields +}) + +// NodeMutationsFromPatch returns the NodeMutations that +// p describes. If p describes something not yet supported +// by a specific NodeMutation type, it returns (nil, false). +func NodeMutationsFromPatch(p *tailcfg.PeerChange) (_ []NodeMutation, ok bool) { + if p == nil || p.NodeID == 0 { + return nil, false + } + var ret []NodeMutation + rv := reflect.ValueOf(p).Elem() + for i, sf := range peerChangeFields() { + if rv.Field(i).IsZero() { + continue + } + switch sf.Name { + default: + // Unhandled field. + return nil, false + case "NodeID": + continue + case "DERPRegion": + ret = append(ret, NodeMutationDERPHome{mutatingNodeID(p.NodeID), p.DERPRegion}) + case "Endpoints": + eps := make([]netip.AddrPort, len(p.Endpoints)) + for i, epStr := range p.Endpoints { + var err error + eps[i], err = netip.ParseAddrPort(epStr) + if err != nil { + return nil, false + } + } + ret = append(ret, NodeMutationEndpoints{mutatingNodeID(p.NodeID), eps}) + case "Online": + ret = append(ret, NodeMutationOnline{mutatingNodeID(p.NodeID), *p.Online}) + case "LastSeen": + ret = append(ret, NodeMutationLastSeen{mutatingNodeID(p.NodeID), *p.LastSeen}) + } + } + return ret, true +} + +// MutationsFromMapResponse returns all the discrete node mutations described +// by res. 
It returns ok=false if res contains any non-patch field as defined +// by mapResponseContainsNonPatchFields. +func MutationsFromMapResponse(res *tailcfg.MapResponse, now time.Time) (ret []NodeMutation, ok bool) { + if now.IsZero() { + now = time.Now() + } + if mapResponseContainsNonPatchFields(res) { + return nil, false + } + // All that remains is PeersChangedPatch, OnlineChange, and LastSeenChange. + + for _, p := range res.PeersChangedPatch { + deltas, ok := NodeMutationsFromPatch(p) + if !ok { + return nil, false + } + ret = append(ret, deltas...) + } + for nid, v := range res.OnlineChange { + ret = append(ret, NodeMutationOnline{mutatingNodeID(nid), v}) + } + for nid, v := range res.PeerSeenChange { + if v { + ret = append(ret, NodeMutationLastSeen{mutatingNodeID(nid), now}) + } + } + slices.SortStableFunc(ret, func(a, b NodeMutation) int { + return cmpx.Compare(a.NodeIDBeingMutated(), b.NodeIDBeingMutated()) + }) + return ret, true +} + +// mapResponseContainsNonPatchFields reports whether res contains only "patch" +// fields set (PeersChangedPatch primarily, but also including the legacy +// PeerSeenChange and OnlineChange fields). +// +// It ignores any of the meta fields that are handled by PollNetMap before the +// peer change handling gets involved. +// +// The purpose of this function is to ask whether this is a tricky enough +// MapResponse to warrant a full netmap update. When this returns false, it +// means the response can be handled incrementally, patching up the local state. 
+func mapResponseContainsNonPatchFields(res *tailcfg.MapResponse) bool {
+	return res.Node != nil ||
+		res.DERPMap != nil ||
+		res.DNSConfig != nil ||
+		res.Domain != "" ||
+		res.CollectServices != "" ||
+		res.PacketFilter != nil ||
+		res.UserProfiles != nil ||
+		res.Health != nil ||
+		res.SSHPolicy != nil ||
+		res.TKAInfo != nil ||
+		res.DomainDataPlaneAuditLogID != "" ||
+		res.Debug != nil ||
+		res.ControlDialPlan != nil ||
+		res.ClientVersion != nil ||
+		res.Peers != nil ||
+		res.PeersRemoved != nil ||
+		// PeersChanged is too coarse to be considered a patch. Also, we convert
+		// PeersChanged to PeersChangedPatch in patchifyPeersChanged before this
+		// function is called, so it should never be set anyway. But for
+		// completeness, and for tests, check it too:
+		res.PeersChanged != nil
+}
diff --git a/vendor/tailscale.com/types/opt/bool.go b/vendor/tailscale.com/types/opt/bool.go
index 66ff1bd820..ca9c048d53 100644
--- a/vendor/tailscale.com/types/opt/bool.go
+++ b/vendor/tailscale.com/types/opt/bool.go
@@ -87,21 +87,15 @@ func (b Bool) MarshalJSON() ([]byte, error) {
 }
 
 func (b *Bool) UnmarshalJSON(j []byte) error {
-	// Note: written with a bunch of ifs instead of a switch
-	// because I'm sure the Go compiler optimizes away these
-	// []byte->string allocations in an == comparison, but I'm too
-	// lazy to check whether that's true in a switch also.
- if string(j) == "true" { + switch string(j) { + case "true": *b = "true" - return nil - } - if string(j) == "false" { + case "false": *b = "false" - return nil - } - if string(j) == "null" { + case "null": *b = "unset" - return nil + default: + return fmt.Errorf("invalid opt.Bool value %q", j) } - return fmt.Errorf("invalid opt.Bool value %q", j) + return nil } diff --git a/vendor/tailscale.com/types/persist/persist.go b/vendor/tailscale.com/types/persist/persist.go index ce1f4c99e4..19df45dcbd 100644 --- a/vendor/tailscale.com/types/persist/persist.go +++ b/vendor/tailscale.com/types/persist/persist.go @@ -35,7 +35,6 @@ type Persist struct { PrivateNodeKey key.NodePrivate OldPrivateNodeKey key.NodePrivate // needed to request key rotation Provider string - LoginName string UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate NodeID tailcfg.StableNodeID @@ -80,8 +79,7 @@ func (p *Persist) Equals(p2 *Persist) bool { p.PrivateNodeKey.Equal(p2.PrivateNodeKey) && p.OldPrivateNodeKey.Equal(p2.OldPrivateNodeKey) && p.Provider == p2.Provider && - p.LoginName == p2.LoginName && - p.UserProfile == p2.UserProfile && + p.UserProfile.Equal(&p2.UserProfile) && p.NetworkLockKey.Equal(p2.NetworkLockKey) && p.NodeID == p2.NodeID && reflect.DeepEqual(nilIfEmpty(p.DisallowedTKAStateIDs), nilIfEmpty(p2.DisallowedTKAStateIDs)) @@ -102,5 +100,5 @@ func (p *Persist) Pretty() string { nk = p.PublicNodeKey() } return fmt.Sprintf("Persist{lm=%v, o=%v, n=%v u=%#v}", - mk.ShortString(), ok.ShortString(), nk.ShortString(), p.LoginName) + mk.ShortString(), ok.ShortString(), nk.ShortString(), p.UserProfile.LoginName) } diff --git a/vendor/tailscale.com/types/persist/persist_clone.go b/vendor/tailscale.com/types/persist/persist_clone.go index 4bce7a03b6..121d906ece 100644 --- a/vendor/tailscale.com/types/persist/persist_clone.go +++ b/vendor/tailscale.com/types/persist/persist_clone.go @@ -19,6 +19,7 @@ func (src *Persist) Clone() *Persist { } dst := new(Persist) *dst = *src + 
dst.UserProfile = *src.UserProfile.Clone() dst.DisallowedTKAStateIDs = append(src.DisallowedTKAStateIDs[:0:0], src.DisallowedTKAStateIDs...) return dst } @@ -30,7 +31,6 @@ var _PersistCloneNeedsRegeneration = Persist(struct { PrivateNodeKey key.NodePrivate OldPrivateNodeKey key.NodePrivate Provider string - LoginName string UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate NodeID tailcfg.StableNodeID diff --git a/vendor/tailscale.com/types/persist/persist_view.go b/vendor/tailscale.com/types/persist/persist_view.go index 81a61b4729..8c3173a234 100644 --- a/vendor/tailscale.com/types/persist/persist_view.go +++ b/vendor/tailscale.com/types/persist/persist_view.go @@ -65,13 +65,12 @@ func (v *PersistView) UnmarshalJSON(b []byte) error { func (v PersistView) LegacyFrontendPrivateMachineKey() key.MachinePrivate { return v.ж.LegacyFrontendPrivateMachineKey } -func (v PersistView) PrivateNodeKey() key.NodePrivate { return v.ж.PrivateNodeKey } -func (v PersistView) OldPrivateNodeKey() key.NodePrivate { return v.ж.OldPrivateNodeKey } -func (v PersistView) Provider() string { return v.ж.Provider } -func (v PersistView) LoginName() string { return v.ж.LoginName } -func (v PersistView) UserProfile() tailcfg.UserProfile { return v.ж.UserProfile } -func (v PersistView) NetworkLockKey() key.NLPrivate { return v.ж.NetworkLockKey } -func (v PersistView) NodeID() tailcfg.StableNodeID { return v.ж.NodeID } +func (v PersistView) PrivateNodeKey() key.NodePrivate { return v.ж.PrivateNodeKey } +func (v PersistView) OldPrivateNodeKey() key.NodePrivate { return v.ж.OldPrivateNodeKey } +func (v PersistView) Provider() string { return v.ж.Provider } +func (v PersistView) UserProfile() tailcfg.UserProfileView { return v.ж.UserProfile.View() } +func (v PersistView) NetworkLockKey() key.NLPrivate { return v.ж.NetworkLockKey } +func (v PersistView) NodeID() tailcfg.StableNodeID { return v.ж.NodeID } func (v PersistView) DisallowedTKAStateIDs() views.Slice[string] { return 
views.SliceOf(v.ж.DisallowedTKAStateIDs) } @@ -83,7 +82,6 @@ var _PersistViewNeedsRegeneration = Persist(struct { PrivateNodeKey key.NodePrivate OldPrivateNodeKey key.NodePrivate Provider string - LoginName string UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate NodeID tailcfg.StableNodeID diff --git a/vendor/tailscale.com/types/views/views.go b/vendor/tailscale.com/types/views/views.go index edcd85189d..5a16712500 100644 --- a/vendor/tailscale.com/types/views/views.go +++ b/vendor/tailscale.com/types/views/views.go @@ -6,15 +6,16 @@ package views import ( + "bytes" "encoding/json" "errors" - "net/netip" + "maps" + "slices" - "golang.org/x/exp/slices" - "tailscale.com/net/tsaddr" + "go4.org/mem" ) -func unmarshalJSON[T any](b []byte, x *[]T) error { +func unmarshalSliceFromJSON[T any](b []byte, x *[]T) error { if *x != nil { return errors.New("already initialized") } @@ -24,6 +25,83 @@ func unmarshalJSON[T any](b []byte, x *[]T) error { return json.Unmarshal(b, x) } +// ByteSlice is a read-only accessor for types that are backed by a []byte. +type ByteSlice[T ~[]byte] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж T +} + +// ByteSliceOf returns a ByteSlice for the provided slice. +func ByteSliceOf[T ~[]byte](x T) ByteSlice[T] { + return ByteSlice[T]{x} +} + +// Len returns the length of the slice. +func (v ByteSlice[T]) Len() int { + return len(v.ж) +} + +// IsNil reports whether the underlying slice is nil. +func (v ByteSlice[T]) IsNil() bool { + return v.ж == nil +} + +// Mem returns a read-only view of the underlying slice. +func (v ByteSlice[T]) Mem() mem.RO { + return mem.B(v.ж) +} + +// Equal reports whether the underlying slice is equal to b. 
+func (v ByteSlice[T]) Equal(b T) bool { + return bytes.Equal(v.ж, b) +} + +// EqualView reports whether the underlying slice is equal to b. +func (v ByteSlice[T]) EqualView(b ByteSlice[T]) bool { + return bytes.Equal(v.ж, b.ж) +} + +// AsSlice returns a copy of the underlying slice. +func (v ByteSlice[T]) AsSlice() T { + return v.AppendTo(v.ж[:0:0]) +} + +// AppendTo appends the underlying slice values to dst. +func (v ByteSlice[T]) AppendTo(dst T) T { + return append(dst, v.ж...) +} + +// LenIter returns a slice the same length as the v.Len(). +// The caller can then range over it to get the valid indexes. +// It does not allocate. +func (v ByteSlice[T]) LenIter() []struct{} { return make([]struct{}, len(v.ж)) } + +// At returns the byte at index `i` of the slice. +func (v ByteSlice[T]) At(i int) byte { return v.ж[i] } + +// SliceFrom returns v[i:]. +func (v ByteSlice[T]) SliceFrom(i int) ByteSlice[T] { return ByteSlice[T]{v.ж[i:]} } + +// SliceTo returns v[:i] +func (v ByteSlice[T]) SliceTo(i int) ByteSlice[T] { return ByteSlice[T]{v.ж[:i]} } + +// Slice returns v[i:j] +func (v ByteSlice[T]) Slice(i, j int) ByteSlice[T] { return ByteSlice[T]{v.ж[i:j]} } + +// MarshalJSON implements json.Marshaler. +func (v ByteSlice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +// UnmarshalJSON implements json.Unmarshaler. +func (v *ByteSlice[T]) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + return json.Unmarshal(b, &v.ж) +} + // StructView represents the corresponding StructView of a Viewable. The concrete types are // typically generated by tailscale.com/cmd/viewer. type StructView[T any] interface { @@ -50,8 +128,9 @@ func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] { return SliceView[T, V]{x} } -// SliceView is a read-only wrapper around a struct which should only be exposed -// as a View. 
+// SliceView wraps []T to provide accessors which return an immutable view V of +// T. It is used to provide the equivalent of SliceOf([]V) without having to +// allocate []V from []T. type SliceView[T ViewCloner[T, V], V StructView[T]] struct { // ж is the underlying mutable value, named with a hard-to-type // character that looks pointy like a pointer. @@ -64,7 +143,7 @@ type SliceView[T ViewCloner[T, V], V StructView[T]] struct { func (v SliceView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } // UnmarshalJSON implements json.Unmarshaler. -func (v *SliceView[T, V]) UnmarshalJSON(b []byte) error { return unmarshalJSON(b, &v.ж) } +func (v *SliceView[T, V]) UnmarshalJSON(b []byte) error { return unmarshalSliceFromJSON(b, &v.ж) } // IsNil reports whether the underlying slice is nil. func (v SliceView[T, V]) IsNil() bool { return v.ж == nil } @@ -72,6 +151,11 @@ func (v SliceView[T, V]) IsNil() bool { return v.ж == nil } // Len returns the length of the slice. func (v SliceView[T, V]) Len() int { return len(v.ж) } +// LenIter returns a slice the same length as the v.Len(). +// The caller can then range over it to get the valid indexes. +// It does not allocate. +func (v SliceView[T, V]) LenIter() []struct{} { return make([]struct{}, len(v.ж)) } + // At returns a View of the element at index `i` of the slice. func (v SliceView[T, V]) At(i int) V { return v.ж[i].View() } @@ -119,7 +203,7 @@ func (v Slice[T]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (v *Slice[T]) UnmarshalJSON(b []byte) error { - return unmarshalJSON(b, &v.ж) + return unmarshalSliceFromJSON(b, &v.ж) } // IsNil reports whether the underlying slice is nil. @@ -128,6 +212,11 @@ func (v Slice[T]) IsNil() bool { return v.ж == nil } // Len returns the length of the slice. func (v Slice[T]) Len() int { return len(v.ж) } +// LenIter returns a slice the same length as the v.Len(). +// The caller can then range over it to get the valid indexes. 
+// It does not allocate. +func (v Slice[T]) LenIter() []struct{} { return make([]struct{}, len(v.ж)) } + // At returns the element at index `i` of the slice. func (v Slice[T]) At(i int) T { return v.ж[i] } @@ -187,6 +276,21 @@ func SliceContains[T comparable](v Slice[T], e T) bool { return false } +// SliceContainsFunc reports whether f reports true for any element in v. +func SliceContainsFunc[T any](v Slice[T], f func(T) bool) bool { + for i := 0; i < v.Len(); i++ { + if f(v.At(i)) { + return true + } + } + return false +} + +// SliceEqual is like the standard library's slices.Equal, but for two views. +func SliceEqual[T comparable](a, b Slice[T]) bool { + return slices.Equal(a.ж, b.ж) +} + // SliceEqualAnyOrder reports whether a and b contain the same elements, regardless of order. // The underlying slices for a and b can be nil. func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { @@ -218,79 +322,6 @@ func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { return true } -// IPPrefixSlice is a read-only accessor for a slice of netip.Prefix. -type IPPrefixSlice struct { - ж Slice[netip.Prefix] -} - -// IPPrefixSliceOf returns a IPPrefixSlice for the provided slice. -func IPPrefixSliceOf(x []netip.Prefix) IPPrefixSlice { return IPPrefixSlice{SliceOf(x)} } - -// IsNil reports whether the underlying slice is nil. -func (v IPPrefixSlice) IsNil() bool { return v.ж.IsNil() } - -// Len returns the length of the slice. -func (v IPPrefixSlice) Len() int { return v.ж.Len() } - -// At returns the IPPrefix at index `i` of the slice. -func (v IPPrefixSlice) At(i int) netip.Prefix { return v.ж.At(i) } - -// AppendTo appends the underlying slice values to dst. -func (v IPPrefixSlice) AppendTo(dst []netip.Prefix) []netip.Prefix { - return v.ж.AppendTo(dst) -} - -// Unwrap returns the underlying Slice[netip.Prefix]. -func (v IPPrefixSlice) Unwrap() Slice[netip.Prefix] { - return v.ж -} - -// AsSlice returns a copy of underlying slice. 
-func (v IPPrefixSlice) AsSlice() []netip.Prefix { - return v.ж.AsSlice() -} - -// Filter returns a new slice, containing elements of v that match f. -func (v IPPrefixSlice) Filter(f func(netip.Prefix) bool) []netip.Prefix { - return tsaddr.FilterPrefixesCopy(v.ж.ж, f) -} - -// PrefixesContainsIP reports whether any IPPrefix contains IP. -func (v IPPrefixSlice) ContainsIP(ip netip.Addr) bool { - return tsaddr.PrefixesContainsIP(v.ж.ж, ip) -} - -// PrefixesContainsFunc reports whether f is true for any IPPrefix in the slice. -func (v IPPrefixSlice) ContainsFunc(f func(netip.Prefix) bool) bool { - return slices.ContainsFunc(v.ж.ж, f) -} - -// ContainsExitRoutes reports whether v contains ExitNode Routes. -func (v IPPrefixSlice) ContainsExitRoutes() bool { - return tsaddr.ContainsExitRoutes(v.ж.ж) -} - -// ContainsNonExitSubnetRoutes reports whether v contains Subnet -// Routes other than ExitNode Routes. -func (v IPPrefixSlice) ContainsNonExitSubnetRoutes() bool { - for i := 0; i < v.Len(); i++ { - if v.At(i).Bits() != 0 { - return true - } - } - return false -} - -// MarshalJSON implements json.Marshaler. -func (v IPPrefixSlice) MarshalJSON() ([]byte, error) { - return v.ж.MarshalJSON() -} - -// UnmarshalJSON implements json.Unmarshaler. -func (v *IPPrefixSlice) UnmarshalJSON(b []byte) error { - return v.ж.UnmarshalJSON(b) -} - // MapOf returns a view over m. It is the caller's responsibility to make sure K // and V is immutable, if this is being used to provide a read-only view over m. func MapOf[K comparable, V comparable](m map[K]V) Map[K, V] { @@ -332,6 +363,30 @@ func (m Map[K, V]) GetOk(k K) (V, bool) { return v, ok } +// MarshalJSON implements json.Marshaler. +func (m Map[K, V]) MarshalJSON() ([]byte, error) { + return json.Marshal(m.ж) +} + +// UnmarshalJSON implements json.Unmarshaler. +// It should only be called on an uninitialized Map. 
+func (m *Map[K, V]) UnmarshalJSON(b []byte) error { + if m.ж != nil { + return errors.New("already initialized") + } + return json.Unmarshal(b, &m.ж) +} + +// AsMap returns a shallow-clone of the underlying map. +// If V is a pointer type, it is the caller's responsibility to make sure +// the values are immutable. +func (m *Map[K, V]) AsMap() map[K]V { + if m == nil { + return nil + } + return maps.Clone(m.ж) +} + // MapRangeFn is the func called from a Map.Range call. // Implementations should return false to stop range. type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool) diff --git a/vendor/tailscale.com/util/cmpx/cmpx.go b/vendor/tailscale.com/util/cmpx/cmpx.go index d747f0a1d9..007d9096a3 100644 --- a/vendor/tailscale.com/util/cmpx/cmpx.go +++ b/vendor/tailscale.com/util/cmpx/cmpx.go @@ -20,3 +20,40 @@ func Or[T comparable](list ...T) T { } return zero } + +// Ordered is cmp.Ordered from Go 1.21. +type Ordered interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | + ~string +} + +// Compare returns +// +// -1 if x is less than y, +// 0 if x equals y, +// +1 if x is greater than y. +// +// For floating-point types, a NaN is considered less than any non-NaN, +// a NaN is considered equal to a NaN, and -0.0 is equal to 0.0. +func Compare[T Ordered](x, y T) int { + xNaN := isNaN(x) + yNaN := isNaN(y) + if xNaN && yNaN { + return 0 + } + if xNaN || x < y { + return -1 + } + if yNaN || x > y { + return +1 + } + return 0 +} + +// isNaN reports whether x is a NaN without requiring the math package. +// This will always return false if T is not floating-point. 
+func isNaN[T Ordered](x T) bool { + return x != x +} diff --git a/vendor/tailscale.com/util/deephash/deephash.go b/vendor/tailscale.com/util/deephash/deephash.go index 0795c2c073..d21ef7ca01 100644 --- a/vendor/tailscale.com/util/deephash/deephash.go +++ b/vendor/tailscale.com/util/deephash/deephash.go @@ -23,11 +23,13 @@ import ( "crypto/sha256" "encoding/binary" "encoding/hex" + "fmt" "reflect" "sync" "time" "tailscale.com/util/hashx" + "tailscale.com/util/set" ) // There is much overlap between the theory of serialization and hashing. @@ -152,12 +154,90 @@ func Hash[T any](v *T) Sum { return h.sum() } +// Option is an optional argument to HasherForType. +type Option interface { + isOption() +} + +type fieldFilterOpt struct { + t reflect.Type + fields set.Set[string] + includeOnMatch bool // true to include fields, false to exclude them +} + +func (fieldFilterOpt) isOption() {} + +func (f fieldFilterOpt) filterStructField(sf reflect.StructField) (include bool) { + if f.fields.Contains(sf.Name) { + return f.includeOnMatch + } + return !f.includeOnMatch +} + +// IncludeFields returns an option that modifies the hashing for T to only +// include the named struct fields. +// +// T must be a struct type, and must match the type of the value passed to +// HasherForType. +func IncludeFields[T any](fields ...string) Option { + return newFieldFilter[T](true, fields) +} + +// ExcludeFields returns an option that modifies the hashing for T to include +// all struct fields of T except those provided in fields. +// +// T must be a struct type, and must match the type of the value passed to +// HasherForType. 
+func ExcludeFields[T any](fields ...string) Option { + return newFieldFilter[T](false, fields) +} + +func newFieldFilter[T any](include bool, fields []string) Option { + var zero T + t := reflect.TypeOf(&zero).Elem() + fieldSet := set.Set[string]{} + for _, f := range fields { + if _, ok := t.FieldByName(f); !ok { + panic(fmt.Sprintf("unknown field %q for type %v", f, t)) + } + fieldSet.Add(f) + } + return fieldFilterOpt{t, fieldSet, include} +} + // HasherForType returns a hash that is specialized for the provided type. -func HasherForType[T any]() func(*T) Sum { +// +// HasherForType panics if the opts are invalid for the provided type. +// +// Currently, at most one option can be provided (IncludeFields or +// ExcludeFields) and its type must match the type of T. Those restrictions may +// be removed in the future, along with documentation about their precedence +// when combined. +func HasherForType[T any](opts ...Option) func(*T) Sum { var v *T seedOnce.Do(initSeed) + if len(opts) > 1 { + panic("HasherForType only accepts one optional argument") // for now + } t := reflect.TypeOf(v).Elem() - hash := lookupTypeHasher(t) + var hash typeHasherFunc + for _, o := range opts { + switch o := o.(type) { + default: + panic(fmt.Sprintf("unknown HasherOpt %T", o)) + case fieldFilterOpt: + if t.Kind() != reflect.Struct { + panic("HasherForStructTypeWithFieldFilter requires T of kind struct") + } + if t != o.t { + panic(fmt.Sprintf("field filter for type %v does not match HasherForType type %v", o.t, t)) + } + hash = makeStructHasher(t, o.filterStructField) + } + } + if hash == nil { + hash = lookupTypeHasher(t) + } return func(v *T) (s Sum) { // This logic is identical to Hash, but pull out a few statements. 
h := hasherPool.Get().(*hasher) @@ -225,7 +305,7 @@ func makeTypeHasher(t reflect.Type) typeHasherFunc { case reflect.Slice: return makeSliceHasher(t) case reflect.Struct: - return makeStructHasher(t) + return makeStructHasher(t, keepAllStructFields) case reflect.Map: return makeMapHasher(t) case reflect.Pointer: @@ -353,9 +433,12 @@ func makeSliceHasher(t reflect.Type) typeHasherFunc { } } -func makeStructHasher(t reflect.Type) typeHasherFunc { +func keepAllStructFields(keepField reflect.StructField) bool { return true } + +func makeStructHasher(t reflect.Type, keepField func(reflect.StructField) bool) typeHasherFunc { type fieldHasher struct { - idx int // index of field for reflect.Type.Field(n); negative if memory is directly hashable + idx int // index of field for reflect.Type.Field(n); negative if memory is directly hashable + keep bool hash typeHasherFunc // only valid if idx is not negative offset uintptr size uintptr @@ -365,8 +448,8 @@ func makeStructHasher(t reflect.Type) typeHasherFunc { init := func() { for i, numField := 0, t.NumField(); i < numField; i++ { sf := t.Field(i) - f := fieldHasher{i, nil, sf.Offset, sf.Type.Size()} - if typeIsMemHashable(sf.Type) { + f := fieldHasher{i, keepField(sf), nil, sf.Offset, sf.Type.Size()} + if f.keep && typeIsMemHashable(sf.Type) { f.idx = -1 } @@ -390,6 +473,9 @@ func makeStructHasher(t reflect.Type) typeHasherFunc { return func(h *hasher, p pointer) { once.Do(init) for _, field := range fields { + if !field.keep { + continue + } pf := p.structField(field.idx, field.offset, field.size) if field.idx < 0 { h.HashBytes(pf.asMemory(field.size)) diff --git a/vendor/tailscale.com/util/goroutines/goroutines.go b/vendor/tailscale.com/util/goroutines/goroutines.go index 7ca7dc6609..9758b07586 100644 --- a/vendor/tailscale.com/util/goroutines/goroutines.go +++ b/vendor/tailscale.com/util/goroutines/goroutines.go @@ -11,15 +11,16 @@ import ( "strconv" ) -// ScrubbedGoroutineDump returns the list of all current 
goroutines, but with the actual -// values of arguments scrubbed out, lest it contain some private key material. -func ScrubbedGoroutineDump() []byte { +// ScrubbedGoroutineDump returns either the current goroutine's stack or all +// goroutines' stacks, but with the actual values of arguments scrubbed out, +// lest it contain some private key material. +func ScrubbedGoroutineDump(all bool) []byte { var buf []byte // Grab stacks multiple times into increasingly larger buffer sizes // to minimize the risk that we blow past our iOS memory limit. for size := 1 << 10; size <= 1<<20; size += 1 << 10 { buf = make([]byte, size) - buf = buf[:runtime.Stack(buf, true)] + buf = buf[:runtime.Stack(buf, all)] if len(buf) < size { // It fit. break diff --git a/vendor/tailscale.com/util/linuxfw/helpers.go b/vendor/tailscale.com/util/linuxfw/helpers.go new file mode 100644 index 0000000000..7526d68ed4 --- /dev/null +++ b/vendor/tailscale.com/util/linuxfw/helpers.go @@ -0,0 +1,35 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package linuxfw + +import ( + "encoding/hex" + "fmt" + "strings" + "unicode" +) + +func formatMaybePrintable(b []byte) string { + // Remove a single trailing null, if any + if len(b) > 0 && b[len(b)-1] == 0 { + b = b[:len(b)-1] + } + + nonprintable := strings.IndexFunc(string(b), func(r rune) bool { + return r > unicode.MaxASCII || !unicode.IsPrint(r) + }) + if nonprintable >= 0 { + return "" + hex.EncodeToString(b) + } + return string(b) +} + +func formatPortRange(r [2]uint16) string { + if r == [2]uint16{0, 65535} { + return fmt.Sprintf(`any`) + } else if r[0] == r[1] { + return fmt.Sprintf(`%d`, r[0]) + } + return fmt.Sprintf(`%d-%d`, r[0], r[1]) +} diff --git a/vendor/tailscale.com/util/linuxfw/iptables.go b/vendor/tailscale.com/util/linuxfw/iptables.go new file mode 100644 index 0000000000..3cc612d033 --- /dev/null +++ b/vendor/tailscale.com/util/linuxfw/iptables.go @@ -0,0 +1,70 @@ +// Copyright (c) Tailscale Inc 
& AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// TODO(#8502): add support for more architectures
+//go:build linux && (arm64 || amd64)
+
+package linuxfw
+
+import (
+	"fmt"
+	"os/exec"
+	"strings"
+	"unicode"
+
+	"tailscale.com/types/logger"
+	"tailscale.com/util/multierr"
+)
+
+// DebugIptables prints debug information about iptables rules to the
+// provided log function.
+func DebugIptables(logf logger.Logf) error {
+	// unused.
+	return nil
+}
+
+// DetectIptables returns the number of iptables rules that are present in the
+// system, ignoring the default "ACCEPT" rule present in the standard iptables
+// chains.
+//
+// It only returns an error when there is no iptables binary, or when iptables -S
+// fails. In all other cases, it returns the number of non-default rules.
+func DetectIptables() (int, error) {
+	// run "iptables -S" to get the list of rules using iptables
+	// exec.Command returns an error if the binary is not found
+	cmd := exec.Command("iptables", "-S")
+	output, err := cmd.Output()
+	ip6cmd := exec.Command("ip6tables", "-S")
+	ip6output, ip6err := ip6cmd.Output()
+	var allLines []string
+	outputStr := string(output)
+	lines := strings.Split(outputStr, "\n")
+	ip6outputStr := string(ip6output)
+	ip6lines := strings.Split(ip6outputStr, "\n")
+	switch {
+	case err == nil && ip6err == nil:
+		allLines = append(lines, ip6lines...)
+ case err == nil && ip6err != nil: + allLines = lines + case err != nil && ip6err == nil: + allLines = ip6lines + default: + return 0, FWModeNotSupportedError{ + Mode: FirewallModeIPTables, + Err: fmt.Errorf("iptables command run fail: %w", multierr.New(err, ip6err)), + } + } + + // count the number of non-default rules + count := 0 + for _, line := range allLines { + trimmedLine := strings.TrimLeftFunc(line, unicode.IsSpace) + if line != "" && strings.HasPrefix(trimmedLine, "-A") { + // if the line is not empty and starts with "-A", it is a rule appended not default + count++ + } + } + + // return the count of non-default rules + return count, nil +} diff --git a/vendor/tailscale.com/util/linuxfw/iptables_runner.go b/vendor/tailscale.com/util/linuxfw/iptables_runner.go new file mode 100644 index 0000000000..14f2fa5363 --- /dev/null +++ b/vendor/tailscale.com/util/linuxfw/iptables_runner.go @@ -0,0 +1,488 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "fmt" + "net/netip" + "os/exec" + "strings" + + "github.com/coreos/go-iptables/iptables" + "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" + "tailscale.com/util/multierr" +) + +type iptablesInterface interface { + // Adding this interface for testing purposes so we can mock out + // the iptables library, in reality this is a wrapper to *iptables.IPTables. 
+ Insert(table, chain string, pos int, args ...string) error + Append(table, chain string, args ...string) error + Exists(table, chain string, args ...string) (bool, error) + Delete(table, chain string, args ...string) error + ClearChain(table, chain string) error + NewChain(table, chain string) error + DeleteChain(table, chain string) error +} + +type iptablesRunner struct { + ipt4 iptablesInterface + ipt6 iptablesInterface + + v6Available bool + v6NATAvailable bool +} + +func checkIP6TablesExists() error { + // Some distros ship ip6tables separately from iptables. + if _, err := exec.LookPath("ip6tables"); err != nil { + return fmt.Errorf("path not found: %w", err) + } + return nil +} + +// NewIPTablesRunner constructs a NetfilterRunner that programs iptables rules. +// If the underlying iptables library fails to initialize, that error is +// returned. The runner probes for IPv6 support once at initialization time and +// if not found, no IPv6 rules will be modified for the lifetime of the runner. +func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { + ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + return nil, err + } + + supportsV6, supportsV6NAT := false, false + v6err := checkIPv6(logf) + ip6terr := checkIP6TablesExists() + switch { + case v6err != nil: + logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) + case ip6terr != nil: + logf("disabling tunneled IPv6 due to missing ip6tables: %v", ip6terr) + default: + supportsV6 = true + supportsV6NAT = supportsV6 && checkSupportsV6NAT() + logf("v6nat = %v", supportsV6NAT) + } + + var ipt6 *iptables.IPTables + if supportsV6 { + ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return nil, err + } + } + return &iptablesRunner{ipt4, ipt6, supportsV6, supportsV6NAT}, nil +} + +// HasIPV6 returns true if the system supports IPv6. 
+func (i *iptablesRunner) HasIPV6() bool { + return i.v6Available +} + +// HasIPV6NAT returns true if the system supports IPv6 NAT. +func (i *iptablesRunner) HasIPV6NAT() bool { + return i.v6NATAvailable +} + +func isErrChainNotExist(err error) bool { + return errCode(err) == 1 +} + +// getIPTByAddr returns the iptablesInterface with correct IP family +// that we will be using for the given address. +func (i *iptablesRunner) getIPTByAddr(addr netip.Addr) iptablesInterface { + nf := i.ipt4 + if addr.Is6() { + nf = i.ipt6 + } + return nf +} + +// AddLoopbackRule adds an iptables rule to permit loopback traffic to +// a local Tailscale IP. +func (i *iptablesRunner) AddLoopbackRule(addr netip.Addr) error { + if err := i.getIPTByAddr(addr).Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { + return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err) + } + + return nil +} + +// tsChain returns the name of the tailscale sub-chain corresponding +// to the given "parent" chain (e.g. INPUT, FORWARD, ...). +func tsChain(chain string) string { + return "ts-" + strings.ToLower(chain) +} + +// DelLoopbackRule removes the iptables rule permitting loopback +// traffic to a Tailscale IP. +func (i *iptablesRunner) DelLoopbackRule(addr netip.Addr) error { + if err := i.getIPTByAddr(addr).Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { + return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err) + } + + return nil +} + +// getTables gets the available iptablesInterface in iptables runner. +func (i *iptablesRunner) getTables() []iptablesInterface { + if i.HasIPV6() { + return []iptablesInterface{i.ipt4, i.ipt6} + } + return []iptablesInterface{i.ipt4} +} + +// getNATTables gets the available iptablesInterface in iptables runner. +// If the system does not support IPv6 NAT, only the IPv4 iptablesInterface +// is returned. 
+func (i *iptablesRunner) getNATTables() []iptablesInterface { + if i.HasIPV6NAT() { + return i.getTables() + } + return []iptablesInterface{i.ipt4} +} + +// AddHooks inserts calls to tailscale's netfilter chains in +// the relevant main netfilter chains. The tailscale chains must +// already exist. If they do not, an error is returned. +func (i *iptablesRunner) AddHooks() error { + // divert inserts a jump to the tailscale chain in the given table/chain. + // If the jump already exists, it is a no-op. + divert := func(ipt iptablesInterface, table, chain string) error { + tsChain := tsChain(chain) + + args := []string{"-j", tsChain} + exists, err := ipt.Exists(table, chain, args...) + if err != nil { + return fmt.Errorf("checking for %v in %s/%s: %w", args, table, chain, err) + } + if exists { + return nil + } + if err := ipt.Insert(table, chain, 1, args...); err != nil { + return fmt.Errorf("adding %v in %s/%s: %w", args, table, chain, err) + } + return nil + } + + for _, ipt := range i.getTables() { + if err := divert(ipt, "filter", "INPUT"); err != nil { + return err + } + if err := divert(ipt, "filter", "FORWARD"); err != nil { + return err + } + } + + for _, ipt := range i.getNATTables() { + if err := divert(ipt, "nat", "POSTROUTING"); err != nil { + return err + } + } + return nil +} + +// AddChains creates custom Tailscale chains in netfilter via iptables +// if the ts-chain doesn't already exist. +func (i *iptablesRunner) AddChains() error { + // create creates a chain in the given table if it doesn't already exist. + // If the chain already exists, it is a no-op. + create := func(ipt iptablesInterface, table, chain string) error { + err := ipt.ClearChain(table, chain) + if isErrChainNotExist(err) { + // nonexistent chain. let's create it! 
+ return ipt.NewChain(table, chain) + } + if err != nil { + return fmt.Errorf("setting up %s/%s: %w", table, chain, err) + } + return nil + } + + for _, ipt := range i.getTables() { + if err := create(ipt, "filter", "ts-input"); err != nil { + return err + } + if err := create(ipt, "filter", "ts-forward"); err != nil { + return err + } + } + + for _, ipt := range i.getNATTables() { + if err := create(ipt, "nat", "ts-postrouting"); err != nil { + return err + } + } + + return nil +} + +// AddBase adds some basic processing rules to be supplemented by +// later calls to other helpers. +func (i *iptablesRunner) AddBase(tunname string) error { + if err := i.addBase4(tunname); err != nil { + return err + } + if i.HasIPV6() { + if err := i.addBase6(tunname); err != nil { + return err + } + } + return nil +} + +// addBase4 adds some basic IPv4 processing rules to be +// supplemented by later calls to other helpers. +func (i *iptablesRunner) addBase4(tunname string) error { + // Only allow CGNAT range traffic to come from tailscale0. There + // is an exception carved out for ranges used by ChromeOS, for + // which we fall out of the Tailscale chain. + // + // Note, this will definitely break nodes that end up using the + // CGNAT range for other purposes :(. + args := []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"} + if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) + } + args = []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} + if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) + } + + // Forward all traffic from the Tailscale interface, and drop + // traffic to the tailscale interface by default. We use packet + // marks here so both filter/FORWARD and nat/POSTROUTING can match + // on these packets of interest.
+ // + // In particular, we only want to apply SNAT rules in + // nat/POSTROUTING to packets that originated from the Tailscale + // interface, but we can't match on the inbound interface in + // POSTROUTING. So instead, we match on the inbound interface in + // filter/FORWARD, and set a packet mark that nat/POSTROUTING can + // use to effectively run that same test again. + args = []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask} + if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) + } + args = []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"} + if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) + } + args = []string{"-o", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} + if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) + } + args = []string{"-o", tunname, "-j", "ACCEPT"} + if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) + } + + return nil +} + +// addBase6 adds some basic IPv6 processing rules to be +// supplemented by later calls to other helpers. +func (i *iptablesRunner) addBase6(tunname string) error { + // TODO: only allow traffic from Tailscale's ULA range to come + // from tailscale0.
+ + args := []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask} + if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) + } + args = []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"} + if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) + } + // TODO: drop forwarded traffic to tailscale0 from tailscale's ULA + // (see corresponding IPv4 CGNAT rule). + args = []string{"-o", tunname, "-j", "ACCEPT"} + if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) + } + + return nil +} + +// DelChains removes the custom Tailscale chains from netfilter via iptables. +func (i *iptablesRunner) DelChains() error { + for _, ipt := range i.getTables() { + if err := delChain(ipt, "filter", "ts-input"); err != nil { + return err + } + if err := delChain(ipt, "filter", "ts-forward"); err != nil { + return err + } + } + + for _, ipt := range i.getNATTables() { + if err := delChain(ipt, "nat", "ts-postrouting"); err != nil { + return err + } + } + + return nil +} + +// DelBase empties but does not remove custom Tailscale chains from +// netfilter via iptables. +func (i *iptablesRunner) DelBase() error { + del := func(ipt iptablesInterface, table, chain string) error { + if err := ipt.ClearChain(table, chain); err != nil { + if isErrChainNotExist(err) { + // nonexistent chain. That's fine, since it's + // the desired state anyway. 
+ return nil + } + return fmt.Errorf("flushing %s/%s: %w", table, chain, err) + } + return nil + } + + for _, ipt := range i.getTables() { + if err := del(ipt, "filter", "ts-input"); err != nil { + return err + } + if err := del(ipt, "filter", "ts-forward"); err != nil { + return err + } + } + for _, ipt := range i.getNATTables() { + if err := del(ipt, "nat", "ts-postrouting"); err != nil { + return err + } + } + + return nil +} + +// DelHooks deletes the calls to tailscale's netfilter chains +// in the relevant main netfilter chains. +func (i *iptablesRunner) DelHooks(logf logger.Logf) error { + for _, ipt := range i.getTables() { + if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil { + return err + } + if err := delTSHook(ipt, "filter", "FORWARD", logf); err != nil { + return err + } + } + for _, ipt := range i.getNATTables() { + if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil { + return err + } + } + + return nil +} + +// AddSNATRule adds a netfilter rule to SNAT traffic destined for +// local subnets. +func (i *iptablesRunner) AddSNATRule() error { + args := []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"} + for _, ipt := range i.getNATTables() { + if err := ipt.Append("nat", "ts-postrouting", args...); err != nil { + return fmt.Errorf("adding %v in nat/ts-postrouting: %w", args, err) + } + } + return nil +} + +// DelSNATRule removes the netfilter rule to SNAT traffic destined for +// local subnets. An error is returned if the rule does not exist. 
+func (i *iptablesRunner) DelSNATRule() error { + args := []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"} + for _, ipt := range i.getNATTables() { + if err := ipt.Delete("nat", "ts-postrouting", args...); err != nil { + return fmt.Errorf("deleting %v in nat/ts-postrouting: %w", args, err) + } + } + return nil +} + +// IPTablesCleanup removes all Tailscale added iptables rules. +// Any errors that occur are logged to the provided logf. +func IPTablesCleanup(logf logger.Logf) { + err := clearRules(iptables.ProtocolIPv4, logf) + if err != nil { + logf("linuxfw: clear iptables: %v", err) + } + + err = clearRules(iptables.ProtocolIPv6, logf) + if err != nil { + logf("linuxfw: clear ip6tables: %v", err) + } +} + +// delTSHook deletes hook in a chain that jumps to a ts-chain. If the hook does not +// exist, it's a no-op since the desired state is already achieved but we log the +// error because error code from the iptables module resists unwrapping. +func delTSHook(ipt iptablesInterface, table, chain string, logf logger.Logf) error { + tsChain := tsChain(chain) + args := []string{"-j", tsChain} + if err := ipt.Delete(table, chain, args...); err != nil { + // TODO(apenwarr): check for errCode(1) here. + // Unfortunately the error code from the iptables + // module resists unwrapping, unlike with other + // calls. So we have to assume if Delete fails, + // it's because there is no such rule. + logf("deleting %v in %s/%s: %v", args, table, chain, err) + return nil + } + return nil +} + +// delChain flushs and deletes a chain. If the chain does not exist, it's a no-op +// since the desired state is already achieved. otherwise, it returns an error. +func delChain(ipt iptablesInterface, table, chain string) error { + if err := ipt.ClearChain(table, chain); err != nil { + if isErrChainNotExist(err) { + // nonexistent chain. nothing to do. 
+ return nil + } + return fmt.Errorf("flushing %s/%s: %w", table, chain, err) + } + if err := ipt.DeleteChain(table, chain); err != nil { + return fmt.Errorf("deleting %s/%s: %w", table, chain, err) + } + return nil +} + +// clearRules clears all the iptables rules created by Tailscale +// for the given protocol. If error occurs, it's logged but not returned. +func clearRules(proto iptables.Protocol, logf logger.Logf) error { + ipt, err := iptables.NewWithProtocol(proto) + if err != nil { + return err + } + + var errs []error + + if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil { + errs = append(errs, err) + } + if err := delTSHook(ipt, "filter", "FORWARD", logf); err != nil { + errs = append(errs, err) + } + if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil { + errs = append(errs, err) + } + + if err := delChain(ipt, "filter", "ts-input"); err != nil { + errs = append(errs, err) + } + if err := delChain(ipt, "filter", "ts-forward"); err != nil { + errs = append(errs, err) + } + + if err := delChain(ipt, "nat", "ts-postrouting"); err != nil { + errs = append(errs, err) + } + + return multierr.New(errs...) +} diff --git a/vendor/tailscale.com/util/linuxfw/linuxfw.go b/vendor/tailscale.com/util/linuxfw/linuxfw.go new file mode 100644 index 0000000000..e381e1f52d --- /dev/null +++ b/vendor/tailscale.com/util/linuxfw/linuxfw.go @@ -0,0 +1,220 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package linuxfw returns the kind of firewall being used by the kernel. + +//go:build linux + +package linuxfw + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + + "github.com/tailscale/netlink" + "tailscale.com/types/logger" +) + +// MatchDecision is the decision made by the firewall for a packet matched by a rule. +// It is used to decide whether to accept or masquerade a packet in addMatchSubnetRouteMarkRule. 
+type MatchDecision int + +const ( + Accept MatchDecision = iota + Masq +) + +type FWModeNotSupportedError struct { + Mode FirewallMode + Err error +} + +func (e FWModeNotSupportedError) Error() string { + return fmt.Sprintf("firewall mode %q not supported: %v", e.Mode, e.Err) +} + +func (e FWModeNotSupportedError) Is(target error) bool { + _, ok := target.(FWModeNotSupportedError) + return ok +} + +func (e FWModeNotSupportedError) Unwrap() error { + return e.Err +} + +type FirewallMode string + +const ( + FirewallModeIPTables FirewallMode = "iptables" + FirewallModeNfTables FirewallMode = "nftables" +) + +// The following bits are added to packet marks for Tailscale use. +// +// We tried to pick bits sufficiently out of the way that it's +// unlikely to collide with existing uses. We have 4 bytes of mark +// bits to play with. We leave the lower byte alone on the assumption +// that sysadmins would use those. Kubernetes uses a few bits in the +// second byte, so we steer clear of that too. +// +// Empirically, most of the documentation on packet marks on the +// internet gives the impression that the marks are 16 bits +// wide. Based on this, we theorize that the upper two bytes are +// relatively unused in the wild, and so we consume bits 16:23 (the +// third byte). +// +// The constants are in the iptables/iproute2 string format for +// matching and setting the bits, so they can be directly embedded in +// commands. +const ( + // The mask for reading/writing the 'firewall mask' bits on a packet. + // See the comment on the const block on why we only use the third byte. + // + // We claim bits 16:23 entirely. For now we only use the lower four + // bits, leaving the higher 4 bits for future use. + TailscaleFwmarkMask = "0xff0000" + TailscaleFwmarkMaskNum = 0xff0000 + + // Packet is from Tailscale and to a subnet route destination, so + // is allowed to be routed through this machine. 
+ TailscaleSubnetRouteMark = "0x40000" + TailscaleSubnetRouteMarkNum = 0x40000 + + // Packet was originated by tailscaled itself, and must not be + // routed over the Tailscale network. + TailscaleBypassMark = "0x80000" + TailscaleBypassMarkNum = 0x80000 +) + +// getTailscaleFwmarkMaskNeg returns the negation of TailscaleFwmarkMask in bytes. +func getTailscaleFwmarkMaskNeg() []byte { + return []byte{0xff, 0x00, 0xff, 0xff} +} + +// getTailscaleFwmarkMask returns the TailscaleFwmarkMask in bytes. +func getTailscaleFwmarkMask() []byte { + return []byte{0x00, 0xff, 0x00, 0x00} +} + +// getTailscaleSubnetRouteMark returns the TailscaleSubnetRouteMark in bytes. +func getTailscaleSubnetRouteMark() []byte { + return []byte{0x00, 0x04, 0x00, 0x00} +} + +// errCode extracts and returns the process exit code from err, or +// zero if err is nil. +func errCode(err error) int { + if err == nil { + return 0 + } + var e *exec.ExitError + if ok := errors.As(err, &e); ok { + return e.ExitCode() + } + s := err.Error() + if strings.HasPrefix(s, "exitcode:") { + code, err := strconv.Atoi(s[9:]) + if err == nil { + return code + } + } + return -42 +} + +// checkIPv6 checks whether the system appears to have a working IPv6 +// network stack. It returns an error explaining what looks wrong or +// missing. It does not check that IPv6 is currently functional or +// that there's a global address, just that the system would support +// IPv6 if it were on an IPv6 network. +func checkIPv6(logf logger.Logf) error { + _, err := os.Stat("/proc/sys/net/ipv6") + if os.IsNotExist(err) { + return err + } + bs, err := os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_ipv6") + if err != nil { + // Be conservative if we can't find the IPv6 configuration knob. 
+ return err + } + disabled, err := strconv.ParseBool(strings.TrimSpace(string(bs))) + if err != nil { + return errors.New("disable_ipv6 has invalid bool") + } + if disabled { + return errors.New("disable_ipv6 is set") + } + + // Older kernels don't support IPv6 policy routing. Some kernels + // support policy routing but don't have this knob, so absence of + // the knob is not fatal. + bs, err = os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_policy") + if err == nil { + disabled, err = strconv.ParseBool(strings.TrimSpace(string(bs))) + if err != nil { + return errors.New("disable_policy has invalid bool") + } + if disabled { + return errors.New("disable_policy is set") + } + } + + if err := CheckIPRuleSupportsV6(logf); err != nil { + return fmt.Errorf("kernel doesn't support IPv6 policy routing: %w", err) + } + + return nil +} + +// checkSupportsV6NAT returns whether the system has a "nat" table in the +// IPv6 netfilter stack. +// +// The nat table was added after the initial release of ipv6 +// netfilter, so some older distros ship a kernel that can't NAT IPv6 +// traffic. +func checkSupportsV6NAT() bool { + bs, err := os.ReadFile("/proc/net/ip6_tables_names") + if err != nil { + // Can't read the file. Assume SNAT works. + return true + } + if bytes.Contains(bs, []byte("nat\n")) { + return true + } + // In nftables mode, that proc file will be empty. Try another thing: + if exec.Command("modprobe", "ip6table_nat").Run() == nil { + return true + } + return false +} + +func CheckIPRuleSupportsV6(logf logger.Logf) error { + // First try just a read-only operation to ideally avoid + // having to modify any state. + if rules, err := netlink.RuleList(netlink.FAMILY_V6); err != nil { + return fmt.Errorf("querying IPv6 policy routing rules: %w", err) + } else { + if len(rules) > 0 { + logf("[v1] kernel supports IPv6 policy routing (found %d rules)", len(rules)) + return nil + } + } + + // Try to actually create & delete one as a test. 
+ rule := netlink.NewRule() + rule.Priority = 1234 + rule.Mark = TailscaleBypassMarkNum + rule.Table = 52 + rule.Family = netlink.FAMILY_V6 + // First delete the rule unconditionally, and don't check for + // errors. This is just cleaning up anything that might be already + // there. + netlink.RuleDel(rule) + // And clean up on exit. + defer netlink.RuleDel(rule) + return netlink.RuleAdd(rule) +} diff --git a/vendor/tailscale.com/util/linuxfw/linuxfw_unsupported.go b/vendor/tailscale.com/util/linuxfw/linuxfw_unsupported.go new file mode 100644 index 0000000000..4c6029af1a --- /dev/null +++ b/vendor/tailscale.com/util/linuxfw/linuxfw_unsupported.go @@ -0,0 +1,40 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// NOTE: linux_{arm64, amd64} are the only two currently supported archs due to missing +// support in upstream dependencies. + +// TODO(#8502): add support for more architectures +//go:build !linux || (linux && !(arm64 || amd64)) + +package linuxfw + +import ( + "errors" + + "tailscale.com/types/logger" +) + +// ErrUnsupported is the error returned from all functions on non-Linux +// platforms. +var ErrUnsupported = errors.New("linuxfw:unsupported") + +// DebugNetfilter is not supported on non-Linux platforms. +func DebugNetfilter(logf logger.Logf) error { + return ErrUnsupported +} + +// DetectNetfilter is not supported on non-Linux platforms. +func DetectNetfilter() (int, error) { + return 0, ErrUnsupported +} + +// DebugIptables is not supported on non-Linux platforms. +func DebugIptables(logf logger.Logf) error { + return ErrUnsupported +} + +// DetectIptables is not supported on non-Linux platforms. 
+func DetectIptables() (int, error) { + return 0, ErrUnsupported +} diff --git a/vendor/tailscale.com/util/linuxfw/nftables.go b/vendor/tailscale.com/util/linuxfw/nftables.go new file mode 100644 index 0000000000..afe6dfa6e3 --- /dev/null +++ b/vendor/tailscale.com/util/linuxfw/nftables.go @@ -0,0 +1,268 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO(#8502): add support for more architectures +//go:build linux && (arm64 || amd64) + +package linuxfw + +import ( + "fmt" + "sort" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + "github.com/google/nftables/xt" + "github.com/josharian/native" + "golang.org/x/sys/unix" + "tailscale.com/types/logger" + "tailscale.com/util/cmpx" +) + +// DebugNetfilter prints debug information about netfilter rules to the +// provided log function. +func DebugNetfilter(logf logger.Logf) error { + conn, err := nftables.New() + if err != nil { + return err + } + + chains, err := conn.ListChains() + if err != nil { + return fmt.Errorf("cannot list chains: %w", err) + } + + if len(chains) == 0 { + logf("netfilter: no chains") + return nil + } + + for _, chain := range chains { + logf("netfilter: table=%s chain=%s", chain.Table.Name, chain.Name) + + rules, err := conn.GetRules(chain.Table, chain) + if err != nil { + continue + } + sort.Slice(rules, func(i, j int) bool { + return rules[i].Position < rules[j].Position + }) + + for i, rule := range rules { + logf("netfilter: rule[%d]: pos=%d flags=%d", i, rule.Position, rule.Flags) + for _, ex := range rule.Exprs { + switch v := ex.(type) { + case *expr.Meta: + key := cmpx.Or(metaKeyNames[v.Key], "UNKNOWN") + logf("netfilter: Meta: key=%s source_register=%v register=%d", key, v.SourceRegister, v.Register) + + case *expr.Cmp: + op := cmpx.Or(cmpOpNames[v.Op], "UNKNOWN") + logf("netfilter: Cmp: op=%s register=%d data=%s", op, v.Register, formatMaybePrintable(v.Data)) + + case *expr.Counter: + // don't print + + case 
*expr.Verdict: + kind := cmpx.Or(verdictNames[v.Kind], "UNKNOWN") + logf("netfilter: Verdict: kind=%s data=%s", kind, v.Chain) + + case *expr.Target: + logf("netfilter: Target: name=%s info=%s", v.Name, printTargetInfo(v.Name, v.Info)) + + case *expr.Match: + logf("netfilter: Match: name=%s info=%+v", v.Name, printMatchInfo(v.Name, v.Info)) + + case *expr.Payload: + logf("netfilter: Payload: op=%s src=%d dst=%d base=%s offset=%d len=%d", + payloadOperationTypeNames[v.OperationType], + v.SourceRegister, v.DestRegister, + payloadBaseNames[v.Base], + v.Offset, v.Len) + // TODO(andrew): csum + + case *expr.Bitwise: + var xor string + for _, b := range v.Xor { + if b != 0 { + xor = fmt.Sprintf(" xor=%v", v.Xor) + break + } + } + logf("netfilter: Bitwise: src=%d dst=%d len=%d mask=%v%s", + v.SourceRegister, v.DestRegister, v.Len, v.Mask, xor) + + default: + logf("netfilter: unknown %T: %+v", v, v) + } + } + } + } + + return nil +} + +// DetectNetfilter returns the number of nftables rules present in the system. 
+func DetectNetfilter() (int, error) { + conn, err := nftables.New() + if err != nil { + return 0, FWModeNotSupportedError{ + Mode: FirewallModeNfTables, + Err: err, + } + } + + chains, err := conn.ListChains() + if err != nil { + return 0, FWModeNotSupportedError{ + Mode: FirewallModeNfTables, + Err: fmt.Errorf("cannot list chains: %w", err), + } + } + + var validRules int + for _, chain := range chains { + rules, err := conn.GetRules(chain.Table, chain) + if err != nil { + continue + } + validRules += len(rules) + } + return validRules, nil +} + +func printMatchInfo(name string, info xt.InfoAny) string { + var sb strings.Builder + sb.WriteString(`{`) + + var handled bool = true + switch v := info.(type) { + // TODO(andrew): we should support these common types + //case *xt.ConntrackMtinfo3: + //case *xt.ConntrackMtinfo2: + case *xt.Tcp: + fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts)) + if v.Option != 0 { + fmt.Fprintf(&sb, " Option:%d", v.Option) + } + if v.FlagsMask != 0 { + fmt.Fprintf(&sb, " FlagsMask:%d", v.FlagsMask) + } + if v.FlagsCmp != 0 { + fmt.Fprintf(&sb, " FlagsCmp:%d", v.FlagsCmp) + } + if v.InvFlags != 0 { + fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags) + } + + case *xt.Udp: + fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts)) + if v.InvFlags != 0 { + fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags) + } + + case *xt.AddrType: + var sprefix, dprefix string + if v.InvertSource { + sprefix = "!" + } + if v.InvertDest { + dprefix = "!" 
+ } + // TODO(andrew): translate source/dest + fmt.Fprintf(&sb, "Source:%s%d Dest:%s%d", sprefix, v.Source, dprefix, v.Dest) + + case *xt.AddrTypeV1: + // TODO(andrew): translate source/dest + fmt.Fprintf(&sb, "Source:%d Dest:%d", v.Source, v.Dest) + + var flags []string + for flag, name := range addrTypeFlagNames { + if v.Flags&flag != 0 { + flags = append(flags, name) + } + } + if len(flags) > 0 { + sort.Strings(flags) + fmt.Fprintf(&sb, "Flags:%s", strings.Join(flags, ",")) + } + + default: + handled = false + } + if handled { + sb.WriteString(`}`) + return sb.String() + } + + unknown, ok := info.(*xt.Unknown) + if !ok { + return fmt.Sprintf("(%T)%+v", info, info) + } + data := []byte(*unknown) + + // Things where upstream has no type + handled = true + switch name { + case "pkttype": + if len(data) != 8 { + handled = false + break + } + + pkttype := int(native.Endian.Uint32(data[0:4])) + invert := int(native.Endian.Uint32(data[4:8])) + var invertPrefix string + if invert != 0 { + invertPrefix = "!" 
+ } + + pkttypeName := packetTypeNames[pkttype] + if pkttypeName != "" { + fmt.Fprintf(&sb, "PktType:%s%s", invertPrefix, pkttypeName) + } else { + fmt.Fprintf(&sb, "PktType:%s%d", invertPrefix, pkttype) + } + + default: + handled = false + } + + if !handled { + return fmt.Sprintf("(%T)%+v", info, info) + } + + sb.WriteString(`}`) + return sb.String() +} + +func printTargetInfo(name string, info xt.InfoAny) string { + var sb strings.Builder + sb.WriteString(`{`) + + unknown, ok := info.(*xt.Unknown) + if !ok { + return fmt.Sprintf("(%T)%+v", info, info) + } + data := []byte(*unknown) + + // Things where upstream has no type + switch name { + case "LOG": + if len(data) != 32 { + fmt.Fprintf(&sb, `Error:"bad size; want 32, got %d"`, len(data)) + break + } + + level := data[0] + logflags := data[1] + prefix := unix.ByteSliceToString(data[2:]) + fmt.Fprintf(&sb, "Level:%d LogFlags:%d Prefix:%q", level, logflags, prefix) + default: + return fmt.Sprintf("(%T)%+v", info, info) + } + + sb.WriteString(`}`) + return sb.String() +} diff --git a/vendor/tailscale.com/util/linuxfw/nftables_runner.go b/vendor/tailscale.com/util/linuxfw/nftables_runner.go new file mode 100644 index 0000000000..9f56c54230 --- /dev/null +++ b/vendor/tailscale.com/util/linuxfw/nftables_runner.go @@ -0,0 +1,1180 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "net" + "net/netip" + "reflect" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" +) + +const ( + chainNameForward = "ts-forward" + chainNameInput = "ts-input" + chainNamePostrouting = "ts-postrouting" +) + +// chainTypeRegular is an nftables chain that does not apply to a hook.
const chainTypeRegular = ""

// chainInfo bundles the parameters needed to create (or verify) one nftables
// chain; consumed by createChainIfNotExist.
type chainInfo struct {
	table         *nftables.Table
	name          string
	chainType     nftables.ChainType
	chainHook     *nftables.ChainHook
	chainPriority *nftables.ChainPriority
	chainPolicy   *nftables.ChainPolicy
}

// nftable groups the filter and nat tables for a single address family.
type nftable struct {
	Proto  nftables.TableFamily
	Filter *nftables.Table
	Nat    *nftables.Table
}

// nftablesRunner implements a netfilterRunner using the netlink based nftables
// library. As nftables allows for arbitrary tables and chains, there is a need
// to follow conventions in order to integrate well with a surrounding
// ecosystem. The rules installed by nftablesRunner have the following
// properties:
//   - Install rules that intend to take precedence over rules installed by
//     other software. Tailscale provides packet filtering for tailnet traffic
//     inside the daemon based on the tailnet ACL rules.
//   - As nftables "accept" is not final, rules from high priority tables (low
//     numbers) will fall through to lower priority tables (high numbers). In
//     order to effectively be 'final', we install "jump" rules into conventional
//     tables and chains that will reach an accept verdict inside those tables.
//   - The table and chain conventions followed here are those used by
//     `iptables-nft` and `ufw`, so that those tools co-exist and do not
//     negatively affect Tailscale function.
type nftablesRunner struct {
	conn *nftables.Conn
	nft4 *nftable // IPv4 tables; always populated
	nft6 *nftable // IPv6 tables; nil when v6Available is false (see NewNfTablesRunner)

	v6Available    bool // system supports tunneled IPv6
	v6NATAvailable bool // system additionally supports IPv6 NAT
}

// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
+func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { + tables, err := c.ListTables() + if err != nil { + return nil, fmt.Errorf("get tables: %w", err) + } + for _, table := range tables { + if table.Name == name && table.Family == family { + return table, nil + } + } + + t := c.AddTable(&nftables.Table{ + Family: family, + Name: name, + }) + if err := c.Flush(); err != nil { + return nil, fmt.Errorf("add table: %w", err) + } + return t, nil +} + +type errorChainNotFound struct { + chainName string + tableName string +} + +func (e errorChainNotFound) Error() string { + return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName) +} + +// getChainFromTable returns the chain with the given name from the given table. +// Note that a chain name is unique within a table. +func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) { + chains, err := c.ListChainsOfTableFamily(table.Family) + if err != nil { + return nil, fmt.Errorf("list chains: %w", err) + } + + for _, chain := range chains { + // Table family is already checked so table name is unique + if chain.Table.Name == table.Name && chain.Name == name { + return chain, nil + } + } + + return nil, errorChainNotFound{table.Name, name} +} + +// getChainsFromTable returns all chains from the given table. +func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) { + chains, err := c.ListChainsOfTableFamily(table.Family) + if err != nil { + return nil, fmt.Errorf("list chains: %w", err) + } + + var ret []*nftables.Chain + for _, chain := range chains { + // Table family is already checked so table name is unique + if chain.Table.Name == table.Name { + ret = append(ret, chain) + } + } + + return ret, nil +} + +// isTSChain reports whether `name` begins with "ts-" (and is thus a +// Tailscale-managed chain). 
func isTSChain(name string) bool {
	return strings.HasPrefix(name, "ts-")
}

// createChainIfNotExist creates a chain with the given name in the given table
// if it does not exist.
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
	chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
	if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
		return fmt.Errorf("get chain: %w", err)
	} else if err == nil {
		// The chain already exists. If it is a TS chain, check the
		// type/hook/priority, but for "conventional chains" assume they're what
		// we expect (in case iptables-nft/ufw make minor behavior changes in
		// the future).
		// NOTE(review): Hooknum and Priority are compared by pointer identity;
		// this appears to rely on the nftables package exposing them as shared
		// sentinel pointers — confirm against the library.
		if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) {
			return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
		}
		return nil
	}

	// Chain was not found: create it with the requested attributes.
	_ = c.AddChain(&nftables.Chain{
		Name:     cinfo.name,
		Table:    cinfo.table,
		Type:     cinfo.chainType,
		Hooknum:  cinfo.chainHook,
		Priority: cinfo.chainPriority,
		Policy:   cinfo.chainPolicy,
	})

	if err := c.Flush(); err != nil {
		return fmt.Errorf("add chain: %w", err)
	}

	return nil
}

// NewNfTablesRunner creates a new nftablesRunner without guaranteeing
// the existence of the tables and chains.
func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
	conn, err := nftables.New()
	if err != nil {
		return nil, fmt.Errorf("nftables connection: %w", err)
	}
	nft4 := &nftable{Proto: nftables.TableFamilyIPv4}

	v6err := checkIPv6(logf)
	if v6err != nil {
		logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
	}
	supportsV6 := v6err == nil
	supportsV6NAT := supportsV6 && checkSupportsV6NAT()

	// nft6 stays nil when IPv6 is unavailable; getTables/getNATTables
	// guard on the v6 flags before dereferencing it.
	var nft6 *nftable
	if supportsV6 {
		logf("v6nat availability: %v", supportsV6NAT)
		nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
	}

	// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables

	return &nftablesRunner{
		conn:           conn,
		nft4:           nft4,
		nft6:           nft6,
		v6Available:    supportsV6,
		v6NATAvailable: supportsV6NAT,
	}, nil
}

// newLoadSaddrExpr creates a new nftables expression that loads the source
// address of the packet into the given register.
func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
	switch proto {
	case nftables.TableFamilyIPv4:
		// IPv4 source address: 4 bytes at offset 12 of the IP header.
		return &expr.Payload{
			DestRegister: destReg,
			Base:         expr.PayloadBaseNetworkHeader,
			Offset:       12,
			Len:          4,
		}, nil
	case nftables.TableFamilyIPv6:
		// IPv6 source address: 16 bytes at offset 8 of the IPv6 header.
		return &expr.Payload{
			DestRegister: destReg,
			Base:         expr.PayloadBaseNetworkHeader,
			Offset:       8,
			Len:          16,
		}, nil
	default:
		return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto)
	}
}

// HasIPV6 returns true if the system supports IPv6.
func (n *nftablesRunner) HasIPV6() bool {
	return n.v6Available
}

// HasIPV6NAT returns true if the system supports IPv6 NAT.
func (n *nftablesRunner) HasIPV6NAT() bool {
	return n.v6NATAvailable
}

// findRule iterates through the rules to find the rule with matching expressions.
func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) {
	rules, err := conn.GetRules(rule.Table, rule.Chain)
	if err != nil {
		return nil, fmt.Errorf("get nftables rules: %w", err)
	}
	if len(rules) == 0 {
		// No match; (nil, nil) means "not found", not an error.
		return nil, nil
	}

ruleLoop:
	for _, r := range rules {
		if len(r.Exprs) != len(rule.Exprs) {
			continue
		}

		for i, e := range r.Exprs {
			// Skip counter expressions, as they will not match.
			// (The kernel reports live packet/byte counts, so a positional
			// counter can never DeepEqual the zero-valued template counter;
			// the length check above keeps positions aligned.)
			if _, ok := e.(*expr.Counter); ok {
				continue
			}
			if !reflect.DeepEqual(e, rule.Exprs[i]) {
				continue ruleLoop
			}
		}
		return r, nil
	}

	return nil, nil
}

// createLoopbackRule builds (without installing) a rule that accepts traffic
// arriving on "lo" whose source address is addr.
func createLoopbackRule(
	proto nftables.TableFamily,
	table *nftables.Table,
	chain *nftables.Chain,
	addr netip.Addr,
) (*nftables.Rule, error) {
	saddrExpr, err := newLoadSaddrExpr(proto, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
	}
	loopBackRule := &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: []expr.Any{
			&expr.Meta{
				Key:      expr.MetaKeyIIFNAME,
				Register: 1,
			},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte("lo"),
			},
			saddrExpr,
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     addr.AsSlice(),
			},
			&expr.Counter{},
			&expr.Verdict{
				Kind: expr.VerdictAccept,
			},
		},
	}
	return loopBackRule, nil
}

// insertLoopbackRule inserts the TS loop back rule into
// the given chain as the first rule if it does not exist.
func insertLoopbackRule(
	conn *nftables.Conn, proto nftables.TableFamily,
	table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error {

	loopBackRule, err := createLoopbackRule(proto, table, chain, addr)
	if err != nil {
		return fmt.Errorf("create loopback rule: %w", err)
	}

	// If TestDial is set, we are running in test mode and we should not
	// find rule because header will mismatch.
	if conn.TestDial == nil {
		// Check if the rule already exists.
		rule, err := findRule(conn, loopBackRule)
		if err != nil {
			return fmt.Errorf("find rule: %w", err)
		}
		if rule != nil {
			// Rule already exists, no need to insert.
			return nil
		}
	}

	// This inserts the rule to the top of the chain
	_ = conn.InsertRule(loopBackRule)

	if err = conn.Flush(); err != nil {
		return fmt.Errorf("insert rule: %w", err)
	}
	return nil
}

// getNFTByAddr returns the nftables with correct IP family
// that we will be using for the given address.
func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
	if addr.Is6() {
		return n.nft6
	}
	return n.nft4
}

// AddLoopbackRule adds an nftables rule to permit loopback traffic to
// a local Tailscale IP. This rule is added only if it does not already exist.
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
	nf := n.getNFTByAddr(addr)

	inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain: %w", err)
	}

	if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
		return fmt.Errorf("add loopback rule: %w", err)
	}

	return nil
}

// DelLoopbackRule removes the nftables rule permitting loopback
// traffic to a Tailscale IP.
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
	nf := n.getNFTByAddr(addr)

	inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain: %w", err)
	}

	// Rebuild the rule we would have installed, then look for it live.
	loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr)
	if err != nil {
		return fmt.Errorf("create loopback rule: %w", err)
	}

	existingLoopBackRule, err := findRule(n.conn, loopBackRule)
	if err != nil {
		return fmt.Errorf("find loop back rule: %w", err)
	}
	if existingLoopBackRule == nil {
		// Rule does not exist, no need to delete.
		return nil
	}

	if err := n.conn.DelRule(existingLoopBackRule); err != nil {
		return fmt.Errorf("delete rule: %w", err)
	}

	return n.conn.Flush()
}

// getTables gets the available nftable in nftables runner.
func (n *nftablesRunner) getTables() []*nftable {
	if n.v6Available {
		return []*nftable{n.nft4, n.nft6}
	}
	return []*nftable{n.nft4}
}

// getNATTables gets the available nftable in nftables runner.
// If the system does not support IPv6 NAT, only the IPv4 nftable
// will be returned.
func (n *nftablesRunner) getNATTables() []*nftable {
	if n.v6NATAvailable {
		return n.getTables()
	}
	return []*nftable{n.nft4}
}

// AddChains creates custom Tailscale chains in netfilter via nftables
// if the ts-chain doesn't already exist.
func (n *nftablesRunner) AddChains() error {
	polAccept := nftables.ChainPolicyAccept
	for _, table := range n.getTables() {
		// Create the filter table if it doesn't exist, this table name is the same
		// as the name used by iptables-nft and ufw. We install rules into the
		// same conventional table so that `accept` verdicts from our jump
		// chains are conclusive.
		filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
		if err != nil {
			return fmt.Errorf("create table: %w", err)
		}
		table.Filter = filter
		// Adding the "conventional chains" that are used by iptables-nft and ufw.
		if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
			return fmt.Errorf("create forward chain: %w", err)
		}
		if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
			return fmt.Errorf("create input chain: %w", err)
		}
		// Adding the tailscale chains that contain our rules.
		if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
			return fmt.Errorf("create forward chain: %w", err)
		}
		if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
			return fmt.Errorf("create input chain: %w", err)
		}
	}

	for _, table := range n.getNATTables() {
		// Create the nat table if it doesn't exist, this table name is the same
		// as the name used by iptables-nft and ufw. We install rules into the
		// same conventional table so that `accept` verdicts from our jump
		// chains are conclusive.
		nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
		if err != nil {
			return fmt.Errorf("create table: %w", err)
		}
		table.Nat = nat
		// Adding the "conventional chains" that are used by iptables-nft and ufw.
		if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
			return fmt.Errorf("create postrouting chain: %w", err)
		}
		// Adding the tailscale chain that contains our rules.
		if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
			return fmt.Errorf("create postrouting chain: %w", err)
		}
	}

	return n.conn.Flush()
}

// deleteChainIfExists deletes a chain if it exists.
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
	chain, err := getChainFromTable(c, table, name)
	if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) {
		return fmt.Errorf("get chain: %w", err)
	} else if err != nil {
		// If the chain doesn't exist, we don't need to delete it.
		return nil
	}

	c.FlushChain(chain)
	c.DelChain(chain)

	if err := c.Flush(); err != nil {
		return fmt.Errorf("flush and delete chain: %w", err)
	}

	return nil
}

// DelChains removes the custom Tailscale chains from netfilter via nftables.
func (n *nftablesRunner) DelChains() error {
	for _, table := range n.getTables() {
		if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil {
			return fmt.Errorf("delete chain: %w", err)
		}
		if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
			return fmt.Errorf("delete chain: %w", err)
		}
	}

	if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
		return fmt.Errorf("delete chain: %w", err)
	}

	if n.v6NATAvailable {
		if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
			return fmt.Errorf("delete chain: %w", err)
		}
	}

	if err := n.conn.Flush(); err != nil {
		return fmt.Errorf("flush: %w", err)
	}

	return nil
}

// createHookRule creates a rule to jump from a hooked chain to a regular chain.
func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
	exprs := []expr.Any{
		&expr.Counter{},
		&expr.Verdict{
			Kind:  expr.VerdictJump,
			Chain: toChainName,
		},
	}

	rule := &nftables.Rule{
		Table: table,
		Chain: fromChain,
		Exprs: exprs,
	}

	return rule
}

// addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain.
+func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { + rule := createHookRule(table, fromChain, toChainName) + _ = conn.InsertRule(rule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush add rule: %w", err) + } + + return nil +} + +// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING" +// in tables and jump from those chains to tailscale chains. +func (n *nftablesRunner) AddHooks() error { + conn := n.conn + + for _, table := range n.getTables() { + inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = addHookRule(conn, table.Filter, inputChain, chainNameInput) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") + if err != nil { + return fmt.Errorf("get FORWARD chain: %w", err) + } + err = addHookRule(conn, table.Filter, forwardChain, chainNameForward) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + } + + for _, table := range n.getNATTables() { + postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + } + return nil +} + +// delHookRule deletes a rule that jumps from a hooked chain to a regular chain. 
+func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { + rule := createHookRule(table, fromChain, toChainName) + existingRule, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("Failed to find hook rule: %w", err) + } + + if existingRule == nil { + return nil + } + + _ = conn.DelRule(existingRule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush del hook rule: %w", err) + } + return nil +} + +// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains. +func (n *nftablesRunner) DelHooks(logf logger.Logf) error { + conn := n.conn + + for _, table := range n.getTables() { + inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = delHookRule(conn, table.Filter, inputChain, chainNameInput) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") + if err != nil { + return fmt.Errorf("get FORWARD chain: %w", err) + } + err = delHookRule(conn, table.Filter, forwardChain, chainNameForward) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + } + + for _, table := range n.getNATTables() { + postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + } + + return nil +} + +// maskof returns the mask of the given prefix in big endian bytes. 
func maskof(pfx netip.Prefix) []byte {
	// IPv4 only (4-byte mask). Go defines x>>s as 0 for s >= 32, so a /32
	// prefix yields 0xffffffff and a /0 prefix yields 0x00000000.
	mask := make([]byte, 4)
	binary.BigEndian.PutUint32(mask, ^(uint32(0xffff_ffff) >> pfx.Bits()))
	return mask
}

// createRangeRule creates a rule that matches packets with source IP from the give
// range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname,
// and makes the given decision. Only IPv4 is supported.
func createRangeRule(
	table *nftables.Table, chain *nftables.Chain,
	tunname string, rng netip.Prefix, decision expr.VerdictKind,
) (*nftables.Rule, error) {
	if rng.Addr().Is6() {
		return nil, errors.New("IPv6 is not supported")
	}
	saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
	}
	netip := rng.Addr().AsSlice()
	mask := maskof(rng)
	// Match: iifname != tunname && (saddr & mask) == rng, then count and decide.
	rule := &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpNeq,
				Register: 1,
				Data:     []byte(tunname),
			},
			saddrExpr,
			&expr.Bitwise{
				SourceRegister: 1,
				DestRegister:   1,
				Len:            4,
				Mask:           mask,
				Xor:            []byte{0x00, 0x00, 0x00, 0x00},
			},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     netip,
			},
			&expr.Counter{},
			&expr.Verdict{
				Kind: decision,
			},
		},
	}
	return rule, nil

}

// addReturnChromeOSVMRangeRule adds a rule to return if the source IP
// is in the ChromeOS VM range.
func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}
	_ = c.AddRule(rule)
	if err = c.Flush(); err != nil {
		return fmt.Errorf("add rule: %w", err)
	}
	return nil
}

// addDropCGNATRangeRule adds a rule to drop if the source IP is in the
// CGNAT range.
func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}
	_ = c.AddRule(rule)
	if err = c.Flush(); err != nil {
		return fmt.Errorf("add rule: %w", err)
	}
	return nil
}

// createSetSubnetRouteMarkRule creates a rule to set the subnet route
// mark if the packet is from the given interface.
func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
	hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg()
	hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()

	rule := &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte(tunname),
			},
			&expr.Counter{},
			// mark = (mark &^ tailscale fwmark mask) | subnet route mark
			&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
			&expr.Bitwise{
				SourceRegister: 1,
				DestRegister:   1,
				Len:            4,
				Mask:           hexTsFwmarkMaskNeg,
				Xor:            hexTSSubnetRouteMark,
			},
			// Write the updated mark back to the packet.
			&expr.Meta{
				Key:            expr.MetaKeyMARK,
				SourceRegister: true,
				Register:       1,
			},
		},
	}
	return rule, nil
}

// addSetSubnetRouteMarkRule adds a rule to set the subnet route mark
// if the packet is from the given interface.
func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	rule, err := createSetSubnetRouteMarkRule(table, chain, tunname)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}
	_ = c.AddRule(rule)

	if err := c.Flush(); err != nil {
		return fmt.Errorf("add rule: %w", err)
	}

	return nil
}

// createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop
// outgoing packets from the CGNAT range.
func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
	_, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String())
	if err != nil {
		return nil, fmt.Errorf("parse cidr: %v", err)
	}
	// net.IPMask.String() renders the mask as hex, so decode it back to bytes.
	mask, err := hex.DecodeString(ipNet.Mask.String())
	if err != nil {
		return nil, fmt.Errorf("decode mask: %v", err)
	}
	netip := ipNet.IP.Mask(ipNet.Mask).To4()
	saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
	if err != nil {
		return nil, fmt.Errorf("newLoadSaddrExpr: %v", err)
	}
	// Match: oifname == tunname && (saddr & mask) == CGNAT network, then drop.
	rule := &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte(tunname),
			},
			saddrExpr,
			&expr.Bitwise{
				SourceRegister: 1,
				DestRegister:   1,
				Len:            4,
				Mask:           mask,
				Xor:            []byte{0x00, 0x00, 0x00, 0x00},
			},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     netip,
			},
			&expr.Counter{},
			&expr.Verdict{
				Kind: expr.VerdictDrop,
			},
		},
	}
	return rule, nil
}

// addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop
// outgoing packets from the CGNAT range.
func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname)
	if err != nil {
		return fmt.Errorf("create rule: %w", err)
	}
	_ = conn.AddRule(rule)

	if err := conn.Flush(); err != nil {
		return fmt.Errorf("add rule: %w", err)
	}
	return nil
}

// createAcceptOutgoingPacketRule creates a rule to accept outgoing packets
// from the given interface.
func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
	// Match: oifname == tunname, then count and accept.
	return &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte(tunname),
			},
			&expr.Counter{},
			&expr.Verdict{
				Kind: expr.VerdictAccept,
			},
		},
	}
}

// addAcceptOutgoingPacketRule adds a rule to accept outgoing packets
// from the given interface.
func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
	rule := createAcceptOutgoingPacketRule(table, chain, tunname)
	_ = conn.AddRule(rule)

	if err := conn.Flush(); err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}

	return nil
}

// AddBase adds some basic processing rules.
func (n *nftablesRunner) AddBase(tunname string) error {
	if err := n.addBase4(tunname); err != nil {
		return fmt.Errorf("add base v4: %w", err)
	}
	// IPv6 rules are only installed when the system supports tunneled IPv6.
	if n.HasIPV6() {
		if err := n.addBase6(tunname); err != nil {
			return fmt.Errorf("add base v6: %w", err)
		}
	}
	return nil
}

// addBase4 adds some basic IPv4 processing rules.
func (n *nftablesRunner) addBase4(tunname string) error {
	conn := n.conn

	// Input chain: pass ChromeOS VM traffic through, drop other CGNAT-sourced
	// traffic that did not arrive over the tunnel.
	inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
	if err != nil {
		return fmt.Errorf("get input chain v4: %v", err)
	}
	if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
		return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
	}
	if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
		return fmt.Errorf("add drop cgnat range rule v4: %w", err)
	}

	forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
	if err != nil {
		return fmt.Errorf("get forward chain v4: %v", err)
	}

	// Forward chain: mark tunnel-originated traffic, accept marked traffic,
	// drop CGNAT-sourced traffic leaving via the tunnel, accept the rest of
	// the tunnel's outgoing traffic. Rule order is significant.
	if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add set subnet route mark rule v4: %w", err)
	}

	if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil {
		return fmt.Errorf("add match subnet route mark rule v4: %w", err)
	}

	if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
	}

	if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
	}

	if err = conn.Flush(); err != nil {
		return fmt.Errorf("flush base v4: %w", err)
	}

	return nil
}

// addBase6 adds some basic IPv6 processing rules.
func (n *nftablesRunner) addBase6(tunname string) error {
	conn := n.conn

	// Unlike addBase4, only forward-chain rules are installed for IPv6
	// (there are no v6 equivalents of the ChromeOS VM / CGNAT input rules).
	forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
	if err != nil {
		return fmt.Errorf("get forward chain v6: %w", err)
	}

	if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add set subnet route mark rule v6: %w", err)
	}

	if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
		return fmt.Errorf("add match subnet route mark rule v6: %w", err)
	}

	if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
		return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
	}

	if err = conn.Flush(); err != nil {
		return fmt.Errorf("flush base v6: %w", err)
	}

	return nil
}

// DelBase empties, but does not remove, custom Tailscale chains from
// netfilter via iptables.
func (n *nftablesRunner) DelBase() error {
	conn := n.conn

	for _, table := range n.getTables() {
		inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
		if err != nil {
			return fmt.Errorf("get input chain: %v", err)
		}
		conn.FlushChain(inputChain)
		forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
		if err != nil {
			return fmt.Errorf("get forward chain: %v", err)
		}
		conn.FlushChain(forwardChain)
	}

	for _, table := range n.getNATTables() {
		postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
		if err != nil {
			return fmt.Errorf("get postrouting chain v4: %v", err)
		}
		conn.FlushChain(postrouteChain)
	}

	return conn.Flush()
}

// createMatchSubnetRouteMarkRule creates a rule that matches packets
// with the subnet route mark and takes the specified action.
func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) {
	hexTSFwmarkMask := getTailscaleFwmarkMask()
	hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()

	// The final expression is either an accept verdict or a masquerade,
	// depending on the requested action.
	var endAction expr.Any
	endAction = &expr.Verdict{Kind: expr.VerdictAccept}
	if action == Masq {
		endAction = &expr.Masq{}
	}

	// Match: (mark & fwmark mask) == subnet route mark, then count and act.
	exprs := []expr.Any{
		&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask:           hexTSFwmarkMask,
			Xor:            []byte{0x00, 0x00, 0x00, 0x00},
		},
		&expr.Cmp{
			Op:       expr.CmpOpEq,
			Register: 1,
			Data:     hexTSSubnetRouteMark,
		},
		&expr.Counter{},
		endAction,
	}

	rule := &nftables.Rule{
		Table: table,
		Chain: chain,
		Exprs: exprs,
	}
	return rule, nil
}

// addMatchSubnetRouteMarkRule adds a rule that matches packets with
// the subnet route mark and takes the specified action.
func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error {
	rule, err := createMatchSubnetRouteMarkRule(table, chain, action)
	if err != nil {
		return fmt.Errorf("create match subnet route mark rule: %w", err)
	}
	_ = conn.AddRule(rule)

	if err := conn.Flush(); err != nil {
		return fmt.Errorf("flush add rule: %w", err)
	}

	return nil
}

// AddSNATRule adds a netfilter rule to SNAT traffic destined for
// local subnets.
+func (n *nftablesRunner) AddSNATRule() error { + conn := n.conn + + for _, table := range n.getNATTables() { + chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) + if err != nil { + return fmt.Errorf("get postrouting chain v4: %w", err) + } + + if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil { + return fmt.Errorf("add match subnet route mark rule v4: %w", err) + } + } + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush add SNAT rule: %w", err) + } + + return nil +} + +// DelSNATRule removes the netfilter rule to SNAT traffic destined for +// local subnets. An error is returned if the rule does not exist. +func (n *nftablesRunner) DelSNATRule() error { + conn := n.conn + + hexTSFwmarkMask := getTailscaleFwmarkMask() + hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: hexTSFwmarkMask, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: hexTSSubnetRouteMark, + }, + &expr.Counter{}, + &expr.Masq{}, + } + + for _, table := range n.getNATTables() { + chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) + if err != nil { + return fmt.Errorf("get postrouting chain v4: %w", err) + } + + rule := &nftables.Rule{ + Table: table.Nat, + Chain: chain, + Exprs: exprs, + } + + SNATRule, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("find SNAT rule v4: %w", err) + } + + _ = conn.DelRule(SNATRule) + } + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush del SNAT rule: %w", err) + } + + return nil +} + +// cleanupChain removes a jump rule from hookChainName to tsChainName, and then +// the entire chain tsChainName. Errors are logged, but attempts to remove both +// the jump rule and chain continue even if one errors. 
+func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) { + // remove the jump first, before removing the jump destination. + defaultChain, err := getChainFromTable(conn, table, hookChainName) + if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { + logf("cleanup: did not find default chain: %s", err) + } + if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { + // delete hook in convention chain + _ = delHookRule(conn, table, defaultChain, tsChainName) + } + + tsChain, err := getChainFromTable(conn, table, tsChainName) + if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) { + logf("cleanup: did not find ts-chain: %s", err) + } + + if tsChain != nil { + // flush and delete ts-chain + conn.FlushChain(tsChain) + conn.DelChain(tsChain) + err = conn.Flush() + logf("cleanup: delete and flush chain %s: %s", tsChainName, err) + } +} + +// NfTablesCleanUp removes all Tailscale added nftables rules. +// Any errors that occur are logged to the provided logf. +func NfTablesCleanUp(logf logger.Logf) { + conn, err := nftables.New() + if err != nil { + logf("cleanup: nftables connection: %s", err) + } + + tables, err := conn.ListTables() // both v4 and v6 + if err != nil { + logf("cleanup: list tables: %s", err) + } + + for _, table := range tables { + // These table names were used briefly in 1.48.0. 
+ if table.Name == "ts-filter" || table.Name == "ts-nat" { + conn.DelTable(table) + if err := conn.Flush(); err != nil { + logf("cleanup: flush delete table %s: %s", table.Name, err) + } + } + + if table.Name == "filter" { + cleanupChain(logf, conn, table, "INPUT", chainNameInput) + cleanupChain(logf, conn, table, "FORWARD", chainNameForward) + } + if table.Name == "nat" { + cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting) + } + } +} diff --git a/vendor/tailscale.com/util/linuxfw/nftables_types.go b/vendor/tailscale.com/util/linuxfw/nftables_types.go new file mode 100644 index 0000000000..b6e24d2a67 --- /dev/null +++ b/vendor/tailscale.com/util/linuxfw/nftables_types.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO(#8502): add support for more architectures +//go:build linux && (arm64 || amd64) + +package linuxfw + +import ( + "github.com/google/nftables/expr" + "github.com/google/nftables/xt" +) + +var metaKeyNames = map[expr.MetaKey]string{ + expr.MetaKeyLEN: "LEN", + expr.MetaKeyPROTOCOL: "PROTOCOL", + expr.MetaKeyPRIORITY: "PRIORITY", + expr.MetaKeyMARK: "MARK", + expr.MetaKeyIIF: "IIF", + expr.MetaKeyOIF: "OIF", + expr.MetaKeyIIFNAME: "IIFNAME", + expr.MetaKeyOIFNAME: "OIFNAME", + expr.MetaKeyIIFTYPE: "IIFTYPE", + expr.MetaKeyOIFTYPE: "OIFTYPE", + expr.MetaKeySKUID: "SKUID", + expr.MetaKeySKGID: "SKGID", + expr.MetaKeyNFTRACE: "NFTRACE", + expr.MetaKeyRTCLASSID: "RTCLASSID", + expr.MetaKeySECMARK: "SECMARK", + expr.MetaKeyNFPROTO: "NFPROTO", + expr.MetaKeyL4PROTO: "L4PROTO", + expr.MetaKeyBRIIIFNAME: "BRIIIFNAME", + expr.MetaKeyBRIOIFNAME: "BRIOIFNAME", + expr.MetaKeyPKTTYPE: "PKTTYPE", + expr.MetaKeyCPU: "CPU", + expr.MetaKeyIIFGROUP: "IIFGROUP", + expr.MetaKeyOIFGROUP: "OIFGROUP", + expr.MetaKeyCGROUP: "CGROUP", + expr.MetaKeyPRANDOM: "PRANDOM", +} + +var cmpOpNames = map[expr.CmpOp]string{ + expr.CmpOpEq: "EQ", + expr.CmpOpNeq: "NEQ", + expr.CmpOpLt: "LT", + expr.CmpOpLte: 
"LTE", + expr.CmpOpGt: "GT", + expr.CmpOpGte: "GTE", +} + +var verdictNames = map[expr.VerdictKind]string{ + expr.VerdictReturn: "RETURN", + expr.VerdictGoto: "GOTO", + expr.VerdictJump: "JUMP", + expr.VerdictBreak: "BREAK", + expr.VerdictContinue: "CONTINUE", + expr.VerdictDrop: "DROP", + expr.VerdictAccept: "ACCEPT", + expr.VerdictStolen: "STOLEN", + expr.VerdictQueue: "QUEUE", + expr.VerdictRepeat: "REPEAT", + expr.VerdictStop: "STOP", +} + +var payloadOperationTypeNames = map[expr.PayloadOperationType]string{ + expr.PayloadLoad: "LOAD", + expr.PayloadWrite: "WRITE", +} + +var payloadBaseNames = map[expr.PayloadBase]string{ + expr.PayloadBaseLLHeader: "ll-header", + expr.PayloadBaseNetworkHeader: "network-header", + expr.PayloadBaseTransportHeader: "transport-header", +} + +var packetTypeNames = map[int]string{ + 0 /* PACKET_HOST */ : "unicast", + 1 /* PACKET_BROADCAST */ : "broadcast", + 2 /* PACKET_MULTICAST */ : "multicast", +} + +var addrTypeFlagNames = map[xt.AddrTypeFlags]string{ + xt.AddrTypeUnspec: "unspec", + xt.AddrTypeUnicast: "unicast", + xt.AddrTypeLocal: "local", + xt.AddrTypeBroadcast: "broadcast", + xt.AddrTypeAnycast: "anycast", + xt.AddrTypeMulticast: "multicast", + xt.AddrTypeBlackhole: "blackhole", + xt.AddrTypeUnreachable: "unreachable", + xt.AddrTypeProhibit: "prohibit", + xt.AddrTypeThrow: "throw", + xt.AddrTypeNat: "nat", + xt.AddrTypeXresolve: "xresolve", +} diff --git a/vendor/tailscale.com/util/multierr/multierr.go b/vendor/tailscale.com/util/multierr/multierr.go index 8f40e01117..93ca068f56 100644 --- a/vendor/tailscale.com/util/multierr/multierr.go +++ b/vendor/tailscale.com/util/multierr/multierr.go @@ -7,9 +7,8 @@ package multierr import ( "errors" + "slices" "strings" - - "golang.org/x/exp/slices" ) // An Error represents multiple errors. 
diff --git a/vendor/tailscale.com/util/osdiag/internal/wsc/wsc_windows.go b/vendor/tailscale.com/util/osdiag/internal/wsc/wsc_windows.go new file mode 100644 index 0000000000..6fb6e54002 --- /dev/null +++ b/vendor/tailscale.com/util/osdiag/internal/wsc/wsc_windows.go @@ -0,0 +1,313 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by 'go generate'; DO NOT EDIT. + +package wsc + +import ( + "runtime" + "syscall" + "unsafe" + + "github.com/dblohm7/wingoes" + "github.com/dblohm7/wingoes/com" + "github.com/dblohm7/wingoes/com/automation" +) + +var ( + CLSID_WSCProductList = &com.CLSID{0x17072F7B, 0x9ABE, 0x4A74, [8]byte{0xA2, 0x61, 0x1E, 0xB7, 0x6B, 0x55, 0x10, 0x7A}} +) + +var ( + IID_IWSCProductList = &com.IID{0x722A338C, 0x6E8E, 0x4E72, [8]byte{0xAC, 0x27, 0x14, 0x17, 0xFB, 0x0C, 0x81, 0xC2}} + IID_IWscProduct = &com.IID{0x8C38232E, 0x3A45, 0x4A27, [8]byte{0x92, 0xB0, 0x1A, 0x16, 0xA9, 0x75, 0xF6, 0x69}} +) + +type WSC_SECURITY_PRODUCT_STATE int32 + +const ( + WSC_SECURITY_PRODUCT_STATE_ON = WSC_SECURITY_PRODUCT_STATE(0) + WSC_SECURITY_PRODUCT_STATE_OFF = WSC_SECURITY_PRODUCT_STATE(1) + WSC_SECURITY_PRODUCT_STATE_SNOOZED = WSC_SECURITY_PRODUCT_STATE(2) + WSC_SECURITY_PRODUCT_STATE_EXPIRED = WSC_SECURITY_PRODUCT_STATE(3) +) + +type WSC_SECURITY_SIGNATURE_STATUS int32 + +const ( + WSC_SECURITY_PRODUCT_OUT_OF_DATE = WSC_SECURITY_SIGNATURE_STATUS(0) + WSC_SECURITY_PRODUCT_UP_TO_DATE = WSC_SECURITY_SIGNATURE_STATUS(1) +) + +type WSC_SECURITY_PROVIDER int32 + +const ( + WSC_SECURITY_PROVIDER_FIREWALL = WSC_SECURITY_PROVIDER(1) + WSC_SECURITY_PROVIDER_AUTOUPDATE_SETTINGS = WSC_SECURITY_PROVIDER(2) + WSC_SECURITY_PROVIDER_ANTIVIRUS = WSC_SECURITY_PROVIDER(4) + WSC_SECURITY_PROVIDER_ANTISPYWARE = WSC_SECURITY_PROVIDER(8) + WSC_SECURITY_PROVIDER_INTERNET_SETTINGS = WSC_SECURITY_PROVIDER(16) + WSC_SECURITY_PROVIDER_USER_ACCOUNT_CONTROL = WSC_SECURITY_PROVIDER(32) + WSC_SECURITY_PROVIDER_SERVICE = 
WSC_SECURITY_PROVIDER(64) + WSC_SECURITY_PROVIDER_NONE = WSC_SECURITY_PROVIDER(0) + WSC_SECURITY_PROVIDER_ALL = WSC_SECURITY_PROVIDER(127) +) + +type SECURITY_PRODUCT_TYPE int32 + +const ( + SECURITY_PRODUCT_TYPE_ANTIVIRUS = SECURITY_PRODUCT_TYPE(0) + SECURITY_PRODUCT_TYPE_FIREWALL = SECURITY_PRODUCT_TYPE(1) + SECURITY_PRODUCT_TYPE_ANTISPYWARE = SECURITY_PRODUCT_TYPE(2) +) + +type IWscProductABI struct { + com.IUnknownABI // Technically IDispatch, but we're bypassing all of that atm +} + +func (abi *IWscProductABI) GetProductName() (pVal string, err error) { + var t0 automation.BSTR + + method := unsafe.Slice(abi.Vtbl, 14)[7] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(unsafe.Pointer(&t0))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + if e.Failed() { + return + } + } + + pVal = t0.String() + t0.Close() + return +} + +func (abi *IWscProductABI) GetProductState() (val WSC_SECURITY_PRODUCT_STATE, err error) { + method := unsafe.Slice(abi.Vtbl, 14)[8] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(unsafe.Pointer(&val))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + } + return +} + +func (abi *IWscProductABI) GetSignatureStatus() (val WSC_SECURITY_SIGNATURE_STATUS, err error) { + method := unsafe.Slice(abi.Vtbl, 14)[9] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(unsafe.Pointer(&val))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + } + return +} + +func (abi *IWscProductABI) GetRemediationPath() (pVal string, err error) { + var t0 automation.BSTR + + method := unsafe.Slice(abi.Vtbl, 14)[10] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(unsafe.Pointer(&t0))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + if e.Failed() { + return + } + } + + pVal = t0.String() + t0.Close() + return +} + +func (abi 
*IWscProductABI) GetProductStateTimestamp() (pVal string, err error) { + var t0 automation.BSTR + + method := unsafe.Slice(abi.Vtbl, 14)[11] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(unsafe.Pointer(&t0))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + if e.Failed() { + return + } + } + + pVal = t0.String() + t0.Close() + return +} + +func (abi *IWscProductABI) GetProductGuid() (pVal string, err error) { + var t0 automation.BSTR + + method := unsafe.Slice(abi.Vtbl, 14)[12] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(unsafe.Pointer(&t0))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + if e.Failed() { + return + } + } + + pVal = t0.String() + t0.Close() + return +} + +func (abi *IWscProductABI) GetProductIsDefault() (pVal bool, err error) { + var t0 int32 + + method := unsafe.Slice(abi.Vtbl, 14)[13] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(unsafe.Pointer(&t0))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + if e.Failed() { + return + } + } + + pVal = t0 != 0 + return +} + +type WscProduct struct { + com.GenericObject[IWscProductABI] +} + +func (o WscProduct) GetProductName() (pVal string, err error) { + p := *(o.Pp) + return p.GetProductName() +} + +func (o WscProduct) GetProductState() (val WSC_SECURITY_PRODUCT_STATE, err error) { + p := *(o.Pp) + return p.GetProductState() +} + +func (o WscProduct) GetSignatureStatus() (val WSC_SECURITY_SIGNATURE_STATUS, err error) { + p := *(o.Pp) + return p.GetSignatureStatus() +} + +func (o WscProduct) GetRemediationPath() (pVal string, err error) { + p := *(o.Pp) + return p.GetRemediationPath() +} + +func (o WscProduct) GetProductStateTimestamp() (pVal string, err error) { + p := *(o.Pp) + return p.GetProductStateTimestamp() +} + +func (o WscProduct) GetProductGuid() (pVal string, err error) { + p := *(o.Pp) + return 
p.GetProductGuid() +} + +func (o WscProduct) GetProductIsDefault() (pVal bool, err error) { + p := *(o.Pp) + return p.GetProductIsDefault() +} + +func (o WscProduct) IID() *com.IID { + return IID_IWscProduct +} + +func (o WscProduct) Make(r com.ABIReceiver) any { + if r == nil { + return WscProduct{} + } + + runtime.SetFinalizer(r, com.ReleaseABI) + + pp := (**IWscProductABI)(unsafe.Pointer(r)) + return WscProduct{com.GenericObject[IWscProductABI]{Pp: pp}} +} + +func (o WscProduct) MakeFromKnownABI(r **IWscProductABI) WscProduct { + if r == nil { + return WscProduct{} + } + + runtime.SetFinalizer(r, func(r **IWscProductABI) { (*r).Release() }) + return WscProduct{com.GenericObject[IWscProductABI]{Pp: r}} +} + +func (o WscProduct) UnsafeUnwrap() *IWscProductABI { + return *(o.Pp) +} + +type IWSCProductListABI struct { + com.IUnknownABI // Technically IDispatch, but we're bypassing all of that atm +} + +func (abi *IWSCProductListABI) Initialize(provider WSC_SECURITY_PROVIDER) (err error) { + method := unsafe.Slice(abi.Vtbl, 10)[7] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(provider)) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + } + return +} + +func (abi *IWSCProductListABI) GetCount() (val int32, err error) { + method := unsafe.Slice(abi.Vtbl, 10)[8] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(unsafe.Pointer(&val))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + } + return +} + +func (abi *IWSCProductListABI) GetItem(index uint32) (val WscProduct, err error) { + var t0 *IWscProductABI + + method := unsafe.Slice(abi.Vtbl, 10)[9] + hr, _, _ := syscall.SyscallN(method, uintptr(unsafe.Pointer(abi)), uintptr(index), uintptr(unsafe.Pointer(&t0))) + if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(hr)); !e.IsOK() { + err = e + if e.Failed() { + return + } + } + + var r0 WscProduct + val = r0.MakeFromKnownABI(&t0) + return +} + +type 
WSCProductList struct { + com.GenericObject[IWSCProductListABI] +} + +func (o WSCProductList) Initialize(provider WSC_SECURITY_PROVIDER) (err error) { + p := *(o.Pp) + return p.Initialize(provider) +} + +func (o WSCProductList) GetCount() (val int32, err error) { + p := *(o.Pp) + return p.GetCount() +} + +func (o WSCProductList) GetItem(index uint32) (val WscProduct, err error) { + p := *(o.Pp) + return p.GetItem(index) +} + +func (o WSCProductList) IID() *com.IID { + return IID_IWSCProductList +} + +func (o WSCProductList) Make(r com.ABIReceiver) any { + if r == nil { + return WSCProductList{} + } + + runtime.SetFinalizer(r, com.ReleaseABI) + + pp := (**IWSCProductListABI)(unsafe.Pointer(r)) + return WSCProductList{com.GenericObject[IWSCProductListABI]{Pp: pp}} +} + +func (o WSCProductList) UnsafeUnwrap() *IWSCProductListABI { + return *(o.Pp) +} diff --git a/vendor/tailscale.com/util/osdiag/mksyscall.go b/vendor/tailscale.com/util/osdiag/mksyscall.go new file mode 100644 index 0000000000..bcbe113b05 --- /dev/null +++ b/vendor/tailscale.com/util/osdiag/mksyscall.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osdiag + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) [int32(failretval)==0] = kernel32.GlobalMemoryStatusEx +//sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW +//sys wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) = ws2_32.WSCEnumProtocols +//sys wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags 
uint32, errno *int32) (ret int32) = ws2_32.WSCGetProviderInfo +//sys wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) = ws2_32.WSCGetProviderPath diff --git a/vendor/tailscale.com/util/osdiag/osdiag.go b/vendor/tailscale.com/util/osdiag/osdiag.go new file mode 100644 index 0000000000..df1bcb3625 --- /dev/null +++ b/vendor/tailscale.com/util/osdiag/osdiag.go @@ -0,0 +1,23 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package osdiag provides loggers for OS-specific diagnostic information. +package osdiag + +import "tailscale.com/types/logger" + +// LogSupportInfoReason is an enumeration indicating the reason for logging +// support info. +type LogSupportInfoReason int + +const ( + LogSupportInfoReasonStartup LogSupportInfoReason = iota + 1 // tailscaled is starting up. + LogSupportInfoReasonBugReport // a bugreport is in the process of being gathered. +) + +// LogSupportInfo obtains OS-specific diagnostic information useful for +// troubleshooting and support, and writes it to logf. The reason argument is +// useful for governing the verbosity of this function's output. 
+func LogSupportInfo(logf logger.Logf, reason LogSupportInfoReason) { + logSupportInfo(logf, reason) +} diff --git a/vendor/tailscale.com/util/osdiag/osdiag_notwindows.go b/vendor/tailscale.com/util/osdiag/osdiag_notwindows.go new file mode 100644 index 0000000000..da172b9676 --- /dev/null +++ b/vendor/tailscale.com/util/osdiag/osdiag_notwindows.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package osdiag + +import "tailscale.com/types/logger" + +func logSupportInfo(logger.Logf, LogSupportInfoReason) { +} diff --git a/vendor/tailscale.com/util/osdiag/osdiag_windows.go b/vendor/tailscale.com/util/osdiag/osdiag_windows.go new file mode 100644 index 0000000000..992f3589dd --- /dev/null +++ b/vendor/tailscale.com/util/osdiag/osdiag_windows.go @@ -0,0 +1,668 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osdiag + +import ( + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "path/filepath" + "strings" + "unicode/utf16" + "unsafe" + + "github.com/dblohm7/wingoes/com" + "github.com/dblohm7/wingoes/pe" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "tailscale.com/types/logger" + "tailscale.com/util/osdiag/internal/wsc" + "tailscale.com/util/winutil" + "tailscale.com/util/winutil/authenticode" +) + +var ( + errUnexpectedResult = errors.New("API call returned an unexpected value") +) + +const ( + maxBinaryValueLen = 128 // we'll truncate any binary values longer than this + maxRegValueNameLen = 16384 // maximum length supported by Windows + 1 + initialValueBufLen = 80 // large enough to contain a stringified GUID encoded as UTF-16 +) + +func logSupportInfo(logf logger.Logf, reason LogSupportInfoReason) { + var b strings.Builder + if err := getSupportInfo(&b, reason); err != nil { + logf("error encoding support info: %v", err) + return + } + logf("%s", b.String()) +} + +const ( + supportInfoKeyModules = 
"modules" + supportInfoKeyPageFile = "pageFile" + supportInfoKeyRegistry = "registry" + supportInfoKeySecurity = "securitySoftware" + supportInfoKeyWinsockLSP = "winsockLSP" +) + +func getSupportInfo(w io.Writer, reason LogSupportInfoReason) error { + output := make(map[string]any) + + regInfo, err := getRegistrySupportInfo(registry.LOCAL_MACHINE, []string{`SOFTWARE\Policies\Tailscale`, winutil.RegBase}) + if err == nil { + output[supportInfoKeyRegistry] = regInfo + } else { + output[supportInfoKeyRegistry] = err + } + + pageFileInfo, err := getPageFileInfo() + if err == nil { + output[supportInfoKeyPageFile] = pageFileInfo + } else { + output[supportInfoKeyPageFile] = err + } + + if reason == LogSupportInfoReasonBugReport { + modInfo, err := getModuleInfo() + if err == nil { + output[supportInfoKeyModules] = modInfo + } else { + output[supportInfoKeyModules] = err + } + + output[supportInfoKeySecurity] = getSecurityInfo() + + lspInfo, err := getWinsockLSPInfo() + if err == nil { + output[supportInfoKeyWinsockLSP] = lspInfo + } else { + output[supportInfoKeyWinsockLSP] = err + } + } + + enc := json.NewEncoder(w) + return enc.Encode(output) +} + +type getRegistrySupportInfoBufs struct { + nameBuf []uint16 + valueBuf []byte +} + +func getRegistrySupportInfo(root registry.Key, subKeys []string) (map[string]any, error) { + bufs := getRegistrySupportInfoBufs{ + nameBuf: make([]uint16, maxRegValueNameLen), + valueBuf: make([]byte, initialValueBufLen), + } + + output := make(map[string]any) + + for _, subKey := range subKeys { + if err := getRegSubKey(root, subKey, 5, &bufs, output); err != nil && !errors.Is(err, registry.ErrNotExist) { + return nil, fmt.Errorf("getRegistrySupportInfo: %w", err) + } + } + + return output, nil +} + +func keyString(key registry.Key, subKey string) string { + var keyStr string + switch key { + case registry.CLASSES_ROOT: + keyStr = `HKCR\` + case registry.CURRENT_USER: + keyStr = `HKCU\` + case registry.LOCAL_MACHINE: + keyStr = `HKLM\` + 
case registry.USERS: + keyStr = `HKU\` + case registry.CURRENT_CONFIG: + keyStr = `HKCC\` + case registry.PERFORMANCE_DATA: + keyStr = `HKPD\` + default: + } + + return keyStr + subKey +} + +func getRegSubKey(key registry.Key, subKey string, recursionLimit int, bufs *getRegistrySupportInfoBufs, output map[string]any) error { + keyStr := keyString(key, subKey) + k, err := registry.OpenKey(key, subKey, registry.READ) + if err != nil { + return fmt.Errorf("opening %q: %w", keyStr, err) + } + defer k.Close() + + kv := make(map[string]any) + index := uint32(0) + +loopValues: + for { + nbuf := bufs.nameBuf + nameLen := uint32(len(nbuf)) + valueType := uint32(0) + vbuf := bufs.valueBuf + valueLen := uint32(len(vbuf)) + + err := regEnumValue(k, index, &nbuf[0], &nameLen, nil, &valueType, &vbuf[0], &valueLen) + switch err { + case windows.ERROR_NO_MORE_ITEMS: + break loopValues + case windows.ERROR_MORE_DATA: + bufs.valueBuf = make([]byte, valueLen) + continue + case nil: + default: + return fmt.Errorf("regEnumValue: %w", err) + } + + var value any + + switch valueType { + case registry.SZ, registry.EXPAND_SZ: + value = windows.UTF16PtrToString((*uint16)(unsafe.Pointer(&vbuf[0]))) + case registry.BINARY: + if valueLen > maxBinaryValueLen { + valueLen = maxBinaryValueLen + } + value = append([]byte{}, vbuf[:valueLen]...) 
+ case registry.DWORD: + value = binary.LittleEndian.Uint32(vbuf[:4]) + case registry.MULTI_SZ: + // Adapted from x/sys/windows/registry/(Key).GetStringsValue + p := (*[1 << 29]uint16)(unsafe.Pointer(&vbuf[0]))[: valueLen/2 : valueLen/2] + var strs []string + if len(p) > 0 { + if p[len(p)-1] == 0 { + p = p[:len(p)-1] + } + strs = make([]string, 0, 5) + from := 0 + for i, c := range p { + if c == 0 { + strs = append(strs, string(utf16.Decode(p[from:i]))) + from = i + 1 + } + } + } + value = strs + case registry.QWORD: + value = binary.LittleEndian.Uint64(vbuf[:8]) + default: + value = fmt.Sprintf("", valueType) + } + + kv[windows.UTF16PtrToString(&nbuf[0])] = value + index++ + } + + if recursionLimit > 0 { + if sks, err := k.ReadSubKeyNames(0); err == nil { + for _, sk := range sks { + if err := getRegSubKey(k, sk, recursionLimit-1, bufs, kv); err != nil { + return err + } + } + } + } + + output[keyStr] = kv + return nil +} + +type moduleInfo struct { + path string `json:"-"` // internal use only + BaseAddress uintptr `json:"baseAddress"` + Size uint32 `json:"size"` + DebugInfo map[string]string `json:"debugInfo,omitempty"` // map for JSON marshaling purposes + DebugInfoErr error `json:"debugInfoErr,omitempty"` + Signature map[string]string `json:"signature,omitempty"` // map for JSON marshaling purposes + SignatureErr error `json:"signatureErr,omitempty"` + VersionInfo map[string]string `json:"versionInfo,omitempty"` // map for JSON marshaling purposes + VersionErr error `json:"versionErr,omitempty"` +} + +func (mi *moduleInfo) setVersionInfo() { + vi, err := pe.NewVersionInfo(mi.path) + if err != nil { + if !errors.Is(err, pe.ErrNotPresent) { + mi.VersionErr = err + } + return + } + + info := map[string]string{ + "": vi.VersionNumber().String(), + } + + ci, err := vi.Field("CompanyName") + if err == nil { + info["companyName"] = ci + } + + mi.VersionInfo = info +} + +var errAssertingType = errors.New("asserting DataDirectory type") + +func (mi *moduleInfo) 
setDebugInfo() { + pem, err := pe.NewPEFromBaseAddressAndSize(mi.BaseAddress, mi.Size) + if err != nil { + mi.DebugInfoErr = err + return + } + defer pem.Close() + + debugDirAny, err := pem.DataDirectoryEntry(pe.IMAGE_DIRECTORY_ENTRY_DEBUG) + if err != nil { + if !errors.Is(err, pe.ErrNotPresent) { + mi.DebugInfoErr = err + } + return + } + + debugDir, ok := debugDirAny.([]pe.IMAGE_DEBUG_DIRECTORY) + if !ok { + mi.DebugInfoErr = errAssertingType + return + } + + for _, dde := range debugDir { + if dde.Type != pe.IMAGE_DEBUG_TYPE_CODEVIEW { + continue + } + + cv, err := pem.ExtractCodeViewInfo(dde) + if err == nil { + mi.DebugInfo = map[string]string{ + "id": cv.String(), + "pdb": strings.ToLower(filepath.Base(cv.PDBPath)), + } + } else { + mi.DebugInfoErr = err + } + + return + } +} + +func (mi *moduleInfo) setAuthenticodeInfo() { + certSubject, provenance, err := authenticode.QueryCertSubject(mi.path) + if err != nil { + if !errors.Is(err, authenticode.ErrSigNotFound) { + mi.SignatureErr = err + } + return + } + + sigInfo := map[string]string{ + "subject": certSubject, + } + + switch provenance { + case authenticode.SigProvEmbedded: + sigInfo["provenance"] = "embedded" + case authenticode.SigProvCatalog: + sigInfo["provenance"] = "catalog" + default: + } + + mi.Signature = sigInfo +} + +func getModuleInfo() (map[string]moduleInfo, error) { + // Take a snapshot of all modules currently loaded into the current process + snap, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPMODULE, 0) + if err != nil { + return nil, err + } + defer windows.CloseHandle(snap) + + result := make(map[string]moduleInfo) + me := windows.ModuleEntry32{ + Size: uint32(unsafe.Sizeof(windows.ModuleEntry32{})), + } + + // Now walk the list + for merr := windows.Module32First(snap, &me); merr == nil; merr = windows.Module32Next(snap, &me) { + name := strings.ToLower(windows.UTF16ToString(me.Module[:])) + path := windows.UTF16ToString(me.ExePath[:]) + base := me.ModBaseAddr + size := 
me.ModBaseSize + + entry := moduleInfo{ + path: path, + BaseAddress: base, + Size: size, + } + + entry.setVersionInfo() + entry.setDebugInfo() + entry.setAuthenticodeInfo() + + result[name] = entry + } + + return result, nil +} + +type _WSC_PROVIDER_INFO_TYPE int32 + +const ( + providerInfoLspCategories _WSC_PROVIDER_INFO_TYPE = 0 +) + +const ( + _SOCKET_ERROR = -1 +) + +// Note that wsaProtocolInfo needs to be identical to windows.WSAProtocolInfo; +// the purpose of this type is to have the ability to use it as a reciever in +// the path and categoryFlags funcs defined below. +type wsaProtocolInfo windows.WSAProtocolInfo + +func (pi *wsaProtocolInfo) path() (string, error) { + var errno int32 + var buf [windows.MAX_PATH]uint16 + bufCount := int32(len(buf)) + ret := wscGetProviderPath(&pi.ProviderId, &buf[0], &bufCount, &errno) + if ret == _SOCKET_ERROR { + return "", windows.Errno(errno) + } + if ret != 0 { + return "", errUnexpectedResult + } + + return windows.UTF16ToString(buf[:bufCount]), nil +} + +func (pi *wsaProtocolInfo) categoryFlags() (uint32, error) { + var errno int32 + var result uint32 + bufLen := uintptr(unsafe.Sizeof(result)) + ret := wscGetProviderInfo(&pi.ProviderId, providerInfoLspCategories, unsafe.Pointer(&result), &bufLen, 0, &errno) + if ret == _SOCKET_ERROR { + return 0, windows.Errno(errno) + } + if ret != 0 { + return 0, errUnexpectedResult + } + + return result, nil +} + +type wsaProtocolInfoOutput struct { + Description string `json:"description,omitempty"` + Version int32 `json:"version"` + AddressFamily int32 `json:"addressFamily"` + SocketType int32 `json:"socketType"` + Protocol int32 `json:"protocol"` + ServiceFlags1 string `json:"serviceFlags1"` + ProviderFlags string `json:"providerFlags"` + Path string `json:"path,omitempty"` + PathErr error `json:"pathErr,omitempty"` + Category string `json:"category,omitempty"` + CategoryErr error `json:"categoryErr,omitempty"` + BaseProviderID string `json:"baseProviderID,omitempty"` + 
LayerProviderID string `json:"layerProviderID,omitempty"` + Chain []uint32 `json:"chain,omitempty"` +} + +func getWinsockLSPInfo() (map[uint32]wsaProtocolInfoOutput, error) { + protocols, err := enumWinsockProtocols() + if err != nil { + return nil, err + } + + result := make(map[uint32]wsaProtocolInfoOutput, len(protocols)) + for _, p := range protocols { + v := wsaProtocolInfoOutput{ + Description: windows.UTF16ToString(p.ProtocolName[:]), + Version: p.Version, + AddressFamily: p.AddressFamily, + SocketType: p.SocketType, + Protocol: p.Protocol, + ServiceFlags1: fmt.Sprintf("0x%08X", p.ServiceFlags1), // Serializing as hex string to make the flags easier to decode by human inspection + ProviderFlags: fmt.Sprintf("0x%08X", p.ProviderFlags), + } + + switch p.ProtocolChain.ChainLen { + case windows.BASE_PROTOCOL: + v.BaseProviderID = p.ProviderId.String() + case windows.LAYERED_PROTOCOL: + v.LayerProviderID = p.ProviderId.String() + default: + v.Chain = p.ProtocolChain.ChainEntries[:p.ProtocolChain.ChainLen] + } + + // Queries that are only valid for base and layered protocols (not chains) + if v.Chain == nil { + path, err := p.path() + if err == nil { + v.Path = strings.ToLower(path) + } else { + v.PathErr = err + } + + category, err := p.categoryFlags() + if err == nil { + v.Category = fmt.Sprintf("0x%08X", category) + } else if !errors.Is(err, windows.WSAEINVALIDPROVIDER) { + // WSAEINVALIDPROVIDER == "no category info found", so we only log + // errors other than that one. + v.CategoryErr = err + } + } + + // Chains reference other providers using catalog entry IDs, so we use that + // value as the key in our map. 
+ result[p.CatalogEntryId] = v + } + + return result, nil +} + +func enumWinsockProtocols() ([]wsaProtocolInfo, error) { + // Get the required size + var errno int32 + var bytesReqd uint32 + ret := wscEnumProtocols(nil, nil, &bytesReqd, &errno) + if ret != _SOCKET_ERROR { + return nil, errUnexpectedResult + } + if e := windows.Errno(errno); e != windows.WSAENOBUFS { + return nil, e + } + + // Allocate + szEntry := uint32(unsafe.Sizeof(wsaProtocolInfo{})) + buf := make([]wsaProtocolInfo, bytesReqd/szEntry) + + // Now do the query for real + bufLen := uint32(len(buf)) * szEntry + ret = wscEnumProtocols(nil, &buf[0], &bufLen, &errno) + if ret == _SOCKET_ERROR { + return nil, windows.Errno(errno) + } + + return buf, nil +} + +type providerKey struct { + provType wsc.WSC_SECURITY_PROVIDER + provKey string +} + +var providerKeys = []providerKey{ + providerKey{ + wsc.WSC_SECURITY_PROVIDER_ANTIVIRUS, + "av", + }, + providerKey{ + wsc.WSC_SECURITY_PROVIDER_ANTISPYWARE, + "antispy", + }, + providerKey{ + wsc.WSC_SECURITY_PROVIDER_FIREWALL, + "firewall", + }, +} + +const ( + maxProvCount = 100 +) + +type secProductInfo struct { + Name string `json:"name,omitempty"` + NameErr error `json:"nameErr,omitempty"` + State string `json:"state,omitempty"` + StateErr error `json:"stateErr,omitempty"` +} + +func getSecurityInfo() map[string]any { + result := make(map[string]any) + + for _, prov := range providerKeys { + // Note that we need to obtain a new product list for each provider type; + // the docs clearly state that we cannot reuse objects. 
+ productList, err := com.CreateInstance[wsc.WSCProductList](wsc.CLSID_WSCProductList) + if err != nil { + result[prov.provKey] = err + continue + } + + err = productList.Initialize(prov.provType) + if err != nil { + result[prov.provKey] = err + continue + } + + n, err := productList.GetCount() + if err != nil { + result[prov.provKey] = err + continue + } + if n == 0 { + continue + } + + n = min(n, maxProvCount) + values := make([]any, 0, n) + + for i := int32(0); i < n; i++ { + product, err := productList.GetItem(uint32(i)) + if err != nil { + values = append(values, err) + continue + } + + var value secProductInfo + + value.Name, err = product.GetProductName() + if err != nil { + value.NameErr = err + } + + state, err := product.GetProductState() + if err == nil { + switch state { + case wsc.WSC_SECURITY_PRODUCT_STATE_ON: + value.State = "on" + case wsc.WSC_SECURITY_PRODUCT_STATE_OFF: + value.State = "off" + case wsc.WSC_SECURITY_PRODUCT_STATE_SNOOZED: + value.State = "snoozed" + case wsc.WSC_SECURITY_PRODUCT_STATE_EXPIRED: + value.State = "expired" + default: + value.State = fmt.Sprintf("", state) + } + } else { + value.StateErr = err + } + + values = append(values, value) + } + + result[prov.provKey] = values + } + + return result +} + +type _MEMORYSTATUSEX struct { + Length uint32 + MemoryLoad uint32 + TotalPhys uint64 + AvailPhys uint64 + TotalPageFile uint64 + AvailPageFile uint64 + TotalVirtual uint64 + AvailVirtual uint64 + AvailExtendedVirtual uint64 +} + +func getPageFileInfo() (map[string]any, error) { + memStatus := _MEMORYSTATUSEX{ + Length: uint32(unsafe.Sizeof(_MEMORYSTATUSEX{})), + } + if err := globalMemoryStatusEx(&memStatus); err != nil { + return nil, err + } + + result := map[string]any{ + "bytesAvailable": memStatus.AvailPageFile, + "bytesTotal": memStatus.TotalPageFile, + } + + if entries, err := getEffectivePageFileValue(); err == nil { + // autoManaged is set to true when there is at least one page file that + // is automatically managed. 
+		autoManaged := false
+
+		// If there is only one entry that consists of only one part, then
+		// the page files are 100% managed by the system.
+		// If there are multiple entries, then each one must be checked.
+		// Each entry then consists of three components, delimited by spaces.
+		// If the latter two components are both "0", then that entry is auto-managed.
+		for _, entry := range entries {
+			if parts := strings.Split(entry, " "); (len(parts) == 1 && len(entries) == 1) ||
+				(len(parts) == 3 && parts[1] == "0" && parts[2] == "0") {
+				autoManaged = true
+				break
+			}
+		}
+
+		result["autoManaged"] = autoManaged
+	}
+
+	return result, nil
+}
+
+func getEffectivePageFileValue() ([]string, error) {
+	const subKey = `SYSTEM\CurrentControlSet\Control\Session Manager\Memory Management`
+	key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.QUERY_VALUE)
+	if err != nil {
+		return nil, err
+	}
+	defer key.Close()
+
+	// Rare but possible case: the user has updated their page file config but
+	// they haven't yet rebooted for the change to take effect. This is the
+	// current setting that the machine is still operating with.
+	if entries, _, err := key.GetStringsValue("ExistingPageFiles"); err == nil {
+		return entries, nil
+	}
+
+	// Otherwise we use this value (yes, the above value uses "Page" and this one uses "Paging").
+	entries, _, err := key.GetStringsValue("PagingFiles")
+	return entries, err
+}
diff --git a/vendor/tailscale.com/util/osdiag/zsyscall_windows.go b/vendor/tailscale.com/util/osdiag/zsyscall_windows.go
new file mode 100644
index 0000000000..ab0d18d3f9
--- /dev/null
+++ b/vendor/tailscale.com/util/osdiag/zsyscall_windows.go
@@ -0,0 +1,85 @@
+// Code generated by 'go generate'; DO NOT EDIT.
+
+package osdiag
+
+import (
+	"syscall"
+	"unsafe"
+
+	"golang.org/x/sys/windows"
+	"golang.org/x/sys/windows/registry"
+)
+
+var _ unsafe.Pointer
+
+// Do the interface allocations only once for common
+// Errno values.
+const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + + procRegEnumValueW = modadvapi32.NewProc("RegEnumValueW") + procGlobalMemoryStatusEx = modkernel32.NewProc("GlobalMemoryStatusEx") + procWSCEnumProtocols = modws2_32.NewProc("WSCEnumProtocols") + procWSCGetProviderInfo = modws2_32.NewProc("WSCGetProviderInfo") + procWSCGetProviderPath = modws2_32.NewProc("WSCGetProviderPath") +) + +func regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) { + r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(valueName)), uintptr(unsafe.Pointer(valueNameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valueType)), uintptr(unsafe.Pointer(pData)), uintptr(unsafe.Pointer(cbData)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) { + r1, _, e1 := syscall.Syscall(procGlobalMemoryStatusEx.Addr(), 1, uintptr(unsafe.Pointer(memStatus)), 0, 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) { + r0, _, _ := 
syscall.Syscall6(procWSCEnumProtocols.Addr(), 4, uintptr(unsafe.Pointer(iProtocols)), uintptr(unsafe.Pointer(protocolBuffer)), uintptr(unsafe.Pointer(bufLen)), uintptr(unsafe.Pointer(errno)), 0, 0) + ret = int32(r0) + return +} + +func wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) { + r0, _, _ := syscall.Syscall6(procWSCGetProviderInfo.Addr(), 6, uintptr(unsafe.Pointer(providerId)), uintptr(infoType), uintptr(info), uintptr(unsafe.Pointer(infoSize)), uintptr(flags), uintptr(unsafe.Pointer(errno))) + ret = int32(r0) + return +} + +func wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) { + r0, _, _ := syscall.Syscall6(procWSCGetProviderPath.Addr(), 4, uintptr(unsafe.Pointer(providerId)), uintptr(unsafe.Pointer(providerDllPath)), uintptr(unsafe.Pointer(providerDllPathLen)), uintptr(unsafe.Pointer(errno)), 0, 0) + ret = int32(r0) + return +} diff --git a/vendor/tailscale.com/util/rands/rands.go b/vendor/tailscale.com/util/rands/rands.go new file mode 100644 index 0000000000..d83e1e5589 --- /dev/null +++ b/vendor/tailscale.com/util/rands/rands.go @@ -0,0 +1,25 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package rands contains utility functions for randomness. +package rands + +import ( + crand "crypto/rand" + "encoding/hex" +) + +// HexString returns a string of n cryptographically random lowercase +// hex characters. +// +// That is, HexString(3) returns something like "0fc", containing 12 +// bits of randomness. 
+func HexString(n int) string { + nb := n / 2 + if n%2 == 1 { + nb++ + } + b := make([]byte, nb) + crand.Read(b) + return hex.EncodeToString(b)[:n] +} diff --git a/vendor/tailscale.com/util/set/set.go b/vendor/tailscale.com/util/set/set.go index 6adb5182f5..e6f3ef1f02 100644 --- a/vendor/tailscale.com/util/set/set.go +++ b/vendor/tailscale.com/util/set/set.go @@ -10,6 +10,9 @@ type Set[T comparable] map[T]struct{} // Add adds e to the set. func (s Set[T]) Add(e T) { s[e] = struct{}{} } +// Delete removes e from the set. +func (s Set[T]) Delete(e T) { delete(s, e) } + // Contains reports whether s contains e. func (s Set[T]) Contains(e T) bool { _, ok := s[e] diff --git a/vendor/tailscale.com/util/set/slice.go b/vendor/tailscale.com/util/set/slice.go index 589b903df9..fe764b550f 100644 --- a/vendor/tailscale.com/util/set/slice.go +++ b/vendor/tailscale.com/util/set/slice.go @@ -4,7 +4,8 @@ package set import ( - "golang.org/x/exp/slices" + "slices" + "tailscale.com/types/views" ) @@ -19,6 +20,9 @@ type Slice[T comparable] struct { // The returned value is only valid until ss is modified again. func (ss *Slice[T]) Slice() views.Slice[T] { return views.SliceOf(ss.slice) } +// Len returns the number of elements in the set. +func (ss *Slice[T]) Len() int { return len(ss.slice) } + // Contains reports whether v is in the set. // The amortized cost is O(1). func (ss *Slice[T]) Contains(v T) bool { diff --git a/vendor/tailscale.com/util/testenv/testenv.go b/vendor/tailscale.com/util/testenv/testenv.go new file mode 100644 index 0000000000..12ada90030 --- /dev/null +++ b/vendor/tailscale.com/util/testenv/testenv.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package testenv provides utility functions for tests. It does not depend on +// the `testing` package to allow usage in non-test code. 
+package testenv + +import ( + "flag" + + "tailscale.com/types/lazy" +) + +var lazyInTest lazy.SyncValue[bool] + +// InTest reports whether the current binary is a test binary. +func InTest() bool { + return lazyInTest.Get(func() bool { + return flag.Lookup("test.v") != nil + }) +} diff --git a/vendor/tailscale.com/util/winutil/authenticode/authenticode_windows.go b/vendor/tailscale.com/util/winutil/authenticode/authenticode_windows.go new file mode 100644 index 0000000000..1ed0b7cfb9 --- /dev/null +++ b/vendor/tailscale.com/util/winutil/authenticode/authenticode_windows.go @@ -0,0 +1,518 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package authenticode + +import ( + "encoding/hex" + "errors" + "fmt" + "path/filepath" + "strings" + "unsafe" + + "github.com/dblohm7/wingoes" + "github.com/dblohm7/wingoes/pe" + "golang.org/x/sys/windows" +) + +var ( + // ErrSigNotFound is returned if no authenticode signature could be found. + ErrSigNotFound = errors.New("authenticode signature not found") + // ErrUnexpectedCertSubject is wrapped with the actual cert subject and + // returned when the binary is signed by a different subject than expected. + ErrUnexpectedCertSubject = errors.New("unexpected cert subject") + errCertSubjectNotFound = errors.New("cert subject not found") + errCertSubjectDecodeLenMismatch = errors.New("length mismatch while decoding cert subject") +) + +const ( + _CERT_STRONG_SIGN_OID_INFO_CHOICE = 2 + _CMSG_SIGNER_CERT_INFO_PARAM = 7 + _MSI_INVALID_HASH_IS_FATAL = 1 + _TRUST_E_NOSIGNATURE = wingoes.HRESULT(-((0x800B0100 ^ 0xFFFFFFFF) + 1)) +) + +// Verify performs authenticode verification on the file at path, and also +// ensures that expectedCertSubject matches the actual cert subject. path may +// point to either a PE binary or an MSI package. ErrSigNotFound is returned if +// no signature is found. 
+func Verify(path string, expectedCertSubject string) error { + path16, err := windows.UTF16PtrFromString(path) + if err != nil { + return err + } + + var subject string + if strings.EqualFold(filepath.Ext(path), ".msi") { + subject, err = verifyMSI(path16) + } else { + subject, _, err = queryPE(path16, true) + } + + if err != nil { + return err + } + + if subject != expectedCertSubject { + return fmt.Errorf("%w %q", ErrUnexpectedCertSubject, subject) + } + + return nil +} + +// SigProvenance indicates whether an authenticode signature was embedded within +// the file itself, or the signature applies to an associated catalog file. +type SigProvenance int + +const ( + SigProvUnknown = SigProvenance(iota) + SigProvEmbedded + SigProvCatalog +) + +// QueryCertSubject obtains the subject associated with the certificate used to +// sign the PE binary located at path. When err == nil, it also returns the +// provenance of that signature. ErrSigNotFound is returned if no signature +// is found. Note that this function does *not* validate the chain of trust; use +// Verify for that purpose! +func QueryCertSubject(path string) (certSubject string, provenance SigProvenance, err error) { + path16, err := windows.UTF16PtrFromString(path) + if err != nil { + return "", SigProvUnknown, err + } + + return queryPE(path16, false) +} + +func queryPE(utf16Path *uint16, verify bool) (string, SigProvenance, error) { + certSubject, err := queryEmbeddedCertSubject(utf16Path, verify) + + switch { + case err == ErrSigNotFound: + // Try looking for the signature in a catalog file. + default: + return certSubject, SigProvEmbedded, err + } + + certSubject, err = queryCatalogCertSubject(utf16Path, verify) + switch { + case err == ErrSigNotFound: + return "", SigProvUnknown, err + default: + return certSubject, SigProvCatalog, err + } +} + +// CertSubjectError is returned if a cert subject was successfully resolved but +// there was a problem encountered during its extraction. 
The Subject is +// provided for informational purposes but is not presumed to be accurate. +type CertSubjectError struct { + Err error // The error that occurred while extracting the cert subject. + Subject string // The (possibly invalid) cert subject that was extracted. +} + +func (e *CertSubjectError) Error() string { + if e == nil { + return "" + } + if e.Subject == "" { + return e.Err.Error() + } + return fmt.Sprintf("cert subject %q: %v", e.Subject, e.Err) +} + +func (e *CertSubjectError) Unwrap() error { + return e.Err +} + +func verifyMSI(path *uint16) (string, error) { + var certCtx *windows.CertContext + hr := msiGetFileSignatureInformation(path, _MSI_INVALID_HASH_IS_FATAL, &certCtx, nil, nil) + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + if e == wingoes.ErrorFromHRESULT(_TRUST_E_NOSIGNATURE) { + return "", ErrSigNotFound + } + return "", e + } + defer windows.CertFreeCertificateContext(certCtx) + + return certSubjectFromCertContext(certCtx) +} + +func certSubjectFromCertContext(certCtx *windows.CertContext) (string, error) { + desiredLen := windows.CertGetNameString( + certCtx, + windows.CERT_NAME_SIMPLE_DISPLAY_TYPE, + 0, + nil, + nil, + 0, + ) + if desiredLen <= 1 { + return "", errCertSubjectNotFound + } + + buf := make([]uint16, desiredLen) + actualLen := windows.CertGetNameString( + certCtx, + windows.CERT_NAME_SIMPLE_DISPLAY_TYPE, + 0, + nil, + &buf[0], + desiredLen, + ) + if actualLen != desiredLen { + return "", errCertSubjectDecodeLenMismatch + } + + return windows.UTF16ToString(buf), nil +} + +type objectQuery struct { + certStore windows.Handle + cryptMsg windows.Handle + encodingType uint32 +} + +func newObjectQuery(utf16Path *uint16) (*objectQuery, error) { + var oq objectQuery + if err := windows.CryptQueryObject( + windows.CERT_QUERY_OBJECT_FILE, + unsafe.Pointer(utf16Path), + windows.CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED, + windows.CERT_QUERY_FORMAT_FLAG_BINARY, + 0, + &oq.encodingType, + nil, + nil, + &oq.certStore, + 
&oq.cryptMsg, + nil, + ); err != nil { + return nil, err + } + + return &oq, nil +} + +func (oq *objectQuery) Close() error { + if oq.certStore != 0 { + if err := windows.CertCloseStore(oq.certStore, 0); err != nil { + return err + } + oq.certStore = 0 + } + + if oq.cryptMsg != 0 { + if err := cryptMsgClose(oq.cryptMsg); err != nil { + return err + } + oq.cryptMsg = 0 + } + + return nil +} + +func (oq *objectQuery) certSubject() (string, error) { + var certInfoLen uint32 + if err := cryptMsgGetParam( + oq.cryptMsg, + _CMSG_SIGNER_CERT_INFO_PARAM, + 0, + unsafe.Pointer(nil), + &certInfoLen, + ); err != nil { + return "", err + } + + buf := make([]byte, certInfoLen) + if err := cryptMsgGetParam( + oq.cryptMsg, + _CMSG_SIGNER_CERT_INFO_PARAM, + 0, + unsafe.Pointer(&buf[0]), + &certInfoLen, + ); err != nil { + return "", err + } + + certInfo := (*windows.CertInfo)(unsafe.Pointer(&buf[0])) + certCtx, err := windows.CertFindCertificateInStore( + oq.certStore, + oq.encodingType, + 0, + windows.CERT_FIND_SUBJECT_CERT, + unsafe.Pointer(certInfo), + nil, + ) + if err != nil { + return "", err + } + defer windows.CertFreeCertificateContext(certCtx) + + return certSubjectFromCertContext(certCtx) +} + +func extractCertBlob(hfile windows.Handle) ([]byte, error) { + pef, err := pe.NewPEFromFileHandle(hfile) + if err != nil { + return nil, err + } + defer pef.Close() + + certsAny, err := pef.DataDirectoryEntry(pe.IMAGE_DIRECTORY_ENTRY_SECURITY) + if err != nil { + if errors.Is(err, pe.ErrNotPresent) { + err = ErrSigNotFound + } + return nil, err + } + + certs, ok := certsAny.([]pe.AuthenticodeCert) + if !ok || len(certs) == 0 { + return nil, ErrSigNotFound + } + + for _, cert := range certs { + if cert.Revision() != pe.WIN_CERT_REVISION_2_0 || cert.Type() != pe.WIN_CERT_TYPE_PKCS_SIGNED_DATA { + continue + } + return cert.Data(), nil + } + + return nil, ErrSigNotFound +} + +type _HCRYPTPROV windows.Handle + +type _CRYPT_VERIFY_MESSAGE_PARA struct { + CBSize uint32 + 
MsgAndCertEncodingType uint32 + HCryptProv _HCRYPTPROV + FNGetSignerCertificate uintptr + GetArg uintptr + StrongSignPara *windows.CertStrongSignPara +} + +func querySubjectFromBlob(blob []byte) (string, error) { + para := _CRYPT_VERIFY_MESSAGE_PARA{ + CBSize: uint32(unsafe.Sizeof(_CRYPT_VERIFY_MESSAGE_PARA{})), + MsgAndCertEncodingType: windows.X509_ASN_ENCODING | windows.PKCS_7_ASN_ENCODING, + } + + var certCtx *windows.CertContext + if err := cryptVerifyMessageSignature(¶, 0, &blob[0], uint32(len(blob)), nil, nil, &certCtx); err != nil { + return "", err + } + defer windows.CertFreeCertificateContext(certCtx) + + return certSubjectFromCertContext(certCtx) +} + +func queryEmbeddedCertSubject(utf16Path *uint16, verify bool) (string, error) { + peBinary, err := windows.CreateFile( + utf16Path, + windows.GENERIC_READ, + windows.FILE_SHARE_READ, + nil, + windows.OPEN_EXISTING, + 0, + 0, + ) + if err != nil { + return "", err + } + defer windows.CloseHandle(peBinary) + + blob, err := extractCertBlob(peBinary) + if err != nil { + return "", err + } + + certSubj, err := querySubjectFromBlob(blob) + if err != nil { + return "", err + } + + if !verify { + return certSubj, nil + } + + wintrustArg := unsafe.Pointer(&windows.WinTrustFileInfo{ + Size: uint32(unsafe.Sizeof(windows.WinTrustFileInfo{})), + FilePath: utf16Path, + File: peBinary, + }) + if err := verifyTrust(windows.WTD_CHOICE_FILE, wintrustArg); err != nil { + // We might still want to know who the cert subject claims to be + // even if the validation has failed (eg for troubleshooting purposes), + // so we return a CertSubjectError. 
+ return "", &CertSubjectError{Err: err, Subject: certSubj} + } + + return certSubj, nil +} + +var ( + _BCRYPT_SHA256_ALGORITHM = &([]uint16{'S', 'H', 'A', '2', '5', '6', 0})[0] + _OID_CERT_STRONG_SIGN_OS_1 = &([]byte("1.3.6.1.4.1.311.72.1.1\x00"))[0] +) + +type _HCATADMIN windows.Handle +type _HCATINFO windows.Handle + +type _CATALOG_INFO struct { + size uint32 + catalogFile [windows.MAX_PATH]uint16 +} + +type _WINTRUST_CATALOG_INFO struct { + size uint32 + catalogVersion uint32 + catalogFilePath *uint16 + memberTag *uint16 + memberFilePath *uint16 + memberFile windows.Handle + pCalculatedFileHash *byte + cbCalculatedFileHash uint32 + catalogContext uintptr + catAdmin _HCATADMIN +} + +func queryCatalogCertSubject(utf16Path *uint16, verify bool) (string, error) { + var catAdmin _HCATADMIN + policy := windows.CertStrongSignPara{ + Size: uint32(unsafe.Sizeof(windows.CertStrongSignPara{})), + InfoChoice: _CERT_STRONG_SIGN_OID_INFO_CHOICE, + InfoOrSerializedInfoOrOID: unsafe.Pointer(_OID_CERT_STRONG_SIGN_OS_1), + } + if err := cryptCATAdminAcquireContext2( + &catAdmin, + nil, + _BCRYPT_SHA256_ALGORITHM, + &policy, + 0, + ); err != nil { + return "", err + } + defer cryptCATAdminReleaseContext(catAdmin, 0) + + // We use windows.CreateFile instead of standard library facilities because: + // 1. Subsequent API calls directly utilize the file's Win32 HANDLE; + // 2. We're going to be hashing the contents of this file, so we want to + // provide a sequential-scan hint to the kernel. 
+ memberFile, err := windows.CreateFile( + utf16Path, + windows.GENERIC_READ, + windows.FILE_SHARE_READ, + nil, + windows.OPEN_EXISTING, + windows.FILE_FLAG_SEQUENTIAL_SCAN, + 0, + ) + if err != nil { + return "", err + } + defer windows.CloseHandle(memberFile) + + var hashLen uint32 + if err := cryptCATAdminCalcHashFromFileHandle2( + catAdmin, + memberFile, + &hashLen, + nil, + 0, + ); err != nil { + return "", err + } + + hashBuf := make([]byte, hashLen) + if err := cryptCATAdminCalcHashFromFileHandle2( + catAdmin, + memberFile, + &hashLen, + &hashBuf[0], + 0, + ); err != nil { + return "", err + } + + catInfoCtx, err := cryptCATAdminEnumCatalogFromHash( + catAdmin, + &hashBuf[0], + hashLen, + 0, + nil, + ) + if err != nil { + if err == windows.ERROR_NOT_FOUND { + err = ErrSigNotFound + } + return "", err + } + defer cryptCATAdminReleaseCatalogContext(catAdmin, catInfoCtx, 0) + + catInfo := _CATALOG_INFO{ + size: uint32(unsafe.Sizeof(_CATALOG_INFO{})), + } + if err := cryptCATAdminCatalogInfoFromContext(catInfoCtx, &catInfo, 0); err != nil { + return "", err + } + + oq, err := newObjectQuery(&catInfo.catalogFile[0]) + if err != nil { + return "", err + } + defer oq.Close() + + certSubj, err := oq.certSubject() + if err != nil { + return "", err + } + + if !verify { + return certSubj, nil + } + + // memberTag is required to be formatted this way. 
+ hbh := strings.ToUpper(hex.EncodeToString(hashBuf)) + memberTag, err := windows.UTF16PtrFromString(hbh) + if err != nil { + return "", err + } + + wintrustArg := unsafe.Pointer(&_WINTRUST_CATALOG_INFO{ + size: uint32(unsafe.Sizeof(_WINTRUST_CATALOG_INFO{})), + catalogFilePath: &catInfo.catalogFile[0], + memberTag: memberTag, + memberFilePath: utf16Path, + memberFile: memberFile, + catAdmin: catAdmin, + }) + if err := verifyTrust(windows.WTD_CHOICE_CATALOG, wintrustArg); err != nil { + // We might still want to know who the cert subject claims to be + // even if the validation has failed (eg for troubleshooting purposes), + // so we return a CertSubjectError. + return "", &CertSubjectError{Err: err, Subject: certSubj} + } + + return certSubj, nil +} + +func verifyTrust(infoType uint32, info unsafe.Pointer) error { + data := &windows.WinTrustData{ + Size: uint32(unsafe.Sizeof(windows.WinTrustData{})), + UIChoice: windows.WTD_UI_NONE, + RevocationChecks: windows.WTD_REVOKE_WHOLECHAIN, // Full revocation checking, as this is called with network connectivity. 
+ UnionChoice: infoType, + StateAction: windows.WTD_STATEACTION_VERIFY, + FileOrCatalogOrBlobOrSgnrOrCert: info, + } + err := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data) + + data.StateAction = windows.WTD_STATEACTION_CLOSE + windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data) + + return err +} diff --git a/vendor/tailscale.com/util/winutil/authenticode/mksyscall.go b/vendor/tailscale.com/util/winutil/authenticode/mksyscall.go new file mode 100644 index 0000000000..8b7cabe6e4 --- /dev/null +++ b/vendor/tailscale.com/util/winutil/authenticode/mksyscall.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package authenticode + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminAcquireContext2 +//sys cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminCalcHashFromFileHandle2 +//sys cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATCatalogInfoFromContext +//sys cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) [ret==0] = wintrust.CryptCATAdminEnumCatalogFromHash +//sys cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseCatalogContext +//sys 
cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseContext +//sys cryptMsgClose(cryptMsg windows.Handle) (err error) [int32(failretval)==0] = crypt32.CryptMsgClose +//sys cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) [int32(failretval)==0] = crypt32.CryptMsgGetParam +//sys cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) [int32(failretval)==0] = crypt32.CryptVerifyMessageSignature +//sys msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) = msi.MsiGetFileSignatureInformationW diff --git a/vendor/tailscale.com/util/winutil/authenticode/zsyscall_windows.go b/vendor/tailscale.com/util/winutil/authenticode/zsyscall_windows.go new file mode 100644 index 0000000000..643721e06a --- /dev/null +++ b/vendor/tailscale.com/util/winutil/authenticode/zsyscall_windows.go @@ -0,0 +1,135 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package authenticode + +import ( + "syscall" + "unsafe" + + "github.com/dblohm7/wingoes" + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. 
(perhaps when running + // all.bat?) + return e +} + +var ( + modcrypt32 = windows.NewLazySystemDLL("crypt32.dll") + modmsi = windows.NewLazySystemDLL("msi.dll") + modwintrust = windows.NewLazySystemDLL("wintrust.dll") + + procCryptMsgClose = modcrypt32.NewProc("CryptMsgClose") + procCryptMsgGetParam = modcrypt32.NewProc("CryptMsgGetParam") + procCryptVerifyMessageSignature = modcrypt32.NewProc("CryptVerifyMessageSignature") + procMsiGetFileSignatureInformationW = modmsi.NewProc("MsiGetFileSignatureInformationW") + procCryptCATAdminAcquireContext2 = modwintrust.NewProc("CryptCATAdminAcquireContext2") + procCryptCATAdminCalcHashFromFileHandle2 = modwintrust.NewProc("CryptCATAdminCalcHashFromFileHandle2") + procCryptCATAdminEnumCatalogFromHash = modwintrust.NewProc("CryptCATAdminEnumCatalogFromHash") + procCryptCATAdminReleaseCatalogContext = modwintrust.NewProc("CryptCATAdminReleaseCatalogContext") + procCryptCATAdminReleaseContext = modwintrust.NewProc("CryptCATAdminReleaseContext") + procCryptCATCatalogInfoFromContext = modwintrust.NewProc("CryptCATCatalogInfoFromContext") +) + +func cryptMsgClose(cryptMsg windows.Handle) (err error) { + r1, _, e1 := syscall.Syscall(procCryptMsgClose.Addr(), 1, uintptr(cryptMsg), 0, 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) { + r1, _, e1 := syscall.Syscall6(procCryptMsgGetParam.Addr(), 5, uintptr(cryptMsg), uintptr(paramType), uintptr(index), uintptr(data), uintptr(unsafe.Pointer(dataLen)), 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) { + r1, _, e1 := syscall.Syscall9(procCryptVerifyMessageSignature.Addr(), 7, 
uintptr(unsafe.Pointer(pVerifyPara)), uintptr(signerIndex), uintptr(unsafe.Pointer(pbSignedBlob)), uintptr(cbSignedBlob), uintptr(unsafe.Pointer(pbDecoded)), uintptr(unsafe.Pointer(pdbDecoded)), uintptr(unsafe.Pointer(ppSignerCert)), 0, 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) { + r0, _, _ := syscall.Syscall6(procMsiGetFileSignatureInformationW.Addr(), 5, uintptr(unsafe.Pointer(signedObjectPath)), uintptr(flags), uintptr(unsafe.Pointer(certCtx)), uintptr(unsafe.Pointer(pbHashData)), uintptr(unsafe.Pointer(cbHashData)), 0) + ret = wingoes.HRESULT(r0) + return +} + +func cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) { + r1, _, e1 := syscall.Syscall6(procCryptCATAdminAcquireContext2.Addr(), 5, uintptr(unsafe.Pointer(hCatAdmin)), uintptr(unsafe.Pointer(pgSubsystem)), uintptr(unsafe.Pointer(hashAlgorithm)), uintptr(unsafe.Pointer(strongHashPolicy)), uintptr(flags), 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) { + r1, _, e1 := syscall.Syscall6(procCryptCATAdminCalcHashFromFileHandle2.Addr(), 5, uintptr(hCatAdmin), uintptr(file), uintptr(unsafe.Pointer(pcbHash)), uintptr(unsafe.Pointer(pbHash)), uintptr(flags), 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) { + r0, _, e1 := syscall.Syscall6(procCryptCATAdminEnumCatalogFromHash.Addr(), 5, uintptr(hCatAdmin), uintptr(unsafe.Pointer(pbHash)), uintptr(cbHash), uintptr(flags), 
uintptr(unsafe.Pointer(prevCatInfo)), 0) + ret = _HCATINFO(r0) + if ret == 0 { + err = errnoErr(e1) + } + return +} + +func cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) { + r1, _, e1 := syscall.Syscall(procCryptCATAdminReleaseCatalogContext.Addr(), 3, uintptr(hCatAdmin), uintptr(hCatInfo), uintptr(flags)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) { + r1, _, e1 := syscall.Syscall(procCryptCATAdminReleaseContext.Addr(), 2, uintptr(hCatAdmin), uintptr(flags), 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) { + r1, _, e1 := syscall.Syscall(procCryptCATCatalogInfoFromContext.Addr(), 3, uintptr(hCatInfo), uintptr(unsafe.Pointer(catInfo)), uintptr(flags)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} diff --git a/vendor/tailscale.com/util/winutil/mksyscall.go b/vendor/tailscale.com/util/winutil/mksyscall.go index f54c3273d0..3c5515ee05 100644 --- a/vendor/tailscale.com/util/winutil/mksyscall.go +++ b/vendor/tailscale.com/util/winutil/mksyscall.go @@ -7,4 +7,4 @@ package winutil //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go //sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W -//sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW +//sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart diff --git a/vendor/tailscale.com/util/winutil/svcdiag_windows.go 
b/vendor/tailscale.com/util/winutil/svcdiag_windows.go index ce8706a062..cd7c150aaa 100644 --- a/vendor/tailscale.com/util/winutil/svcdiag_windows.go +++ b/vendor/tailscale.com/util/winutil/svcdiag_windows.go @@ -14,6 +14,7 @@ import ( "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" "tailscale.com/types/logger" + "tailscale.com/util/set" ) // LogSvcState obtains the state of the Windows service named rootSvcName and @@ -78,7 +79,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error { } }() - seen := make(map[string]struct{}) + seen := set.Set[string]{} for err == nil && len(deps) > 0 { err = func() error { @@ -87,7 +88,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error { deps = deps[:len(deps)-1] - seen[curSvc.Name] = struct{}{} + seen.Add(curSvc.Name) curCfg, err := curSvc.Config() if err != nil { @@ -97,7 +98,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error { callback(curSvc, curCfg) for _, depName := range curCfg.Dependencies { - if _, ok := seen[depName]; ok { + if seen.Contains(depName) { continue } diff --git a/vendor/tailscale.com/util/winutil/winutil.go b/vendor/tailscale.com/util/winutil/winutil.go index 9ccc9e0858..3ec3f7c990 100644 --- a/vendor/tailscale.com/util/winutil/winutil.go +++ b/vendor/tailscale.com/util/winutil/winutil.go @@ -75,3 +75,27 @@ func IsSIDValidPrincipal(uid string) bool { func LookupPseudoUser(uid string) (*user.User, error) { return lookupPseudoUser(uid) } + +// RegisterForRestartOpts supplies options to RegisterForRestart. +type RegisterForRestartOpts struct { + RestartOnCrash bool // When true, this program will be restarted after a crash. + RestartOnHang bool // When true, this program will be restarted after a hang. + RestartOnUpgrade bool // When true, this program will be restarted after an upgrade. + RestartOnReboot bool // When true, this program will be restarted after a reboot. 
+ UseCmdLineArgs bool // When true, CmdLineArgs will be used as the program's arguments upon restart. Otherwise no arguments will be provided. + CmdLineArgs []string // When UseCmdLineArgs == true, contains the command line arguments, excluding the executable name itself. If nil or empty, the arguments from the current process will be re-used. +} + +// RegisterForRestart registers the current process' restart preferences with +// the Windows Restart Manager. This enables the OS to intelligently restart +// the calling executable as requested via opts. This should be called by any +// programs which need to be restarted by the installer post-update. +// +// This function may be called multiple times; the opts from the most recent +// call will override those from any previous invocations. +// +// This function will only work on GOOS=windows. Trying to run it on any other +// OS will always return nil. +func RegisterForRestart(opts RegisterForRestartOpts) error { + return registerForRestart(opts) +} diff --git a/vendor/tailscale.com/util/winutil/winutil_notwindows.go b/vendor/tailscale.com/util/winutil/winutil_notwindows.go index 0f2c4a83a3..c9a292aae0 100644 --- a/vendor/tailscale.com/util/winutil/winutil_notwindows.go +++ b/vendor/tailscale.com/util/winutil/winutil_notwindows.go @@ -28,3 +28,5 @@ func lookupPseudoUser(uid string) (*user.User, error) { } func IsCurrentProcessElevated() bool { return false } + +func registerForRestart(opts RegisterForRestartOpts) error { return nil } diff --git a/vendor/tailscale.com/util/winutil/winutil_windows.go b/vendor/tailscale.com/util/winutil/winutil_windows.go index 1b2eff00fa..ed516ce6b2 100644 --- a/vendor/tailscale.com/util/winutil/winutil_windows.go +++ b/vendor/tailscale.com/util/winutil/winutil_windows.go @@ -4,24 +4,21 @@ package winutil import ( - "encoding/binary" - "encoding/json" "errors" "fmt" - "io" "log" + "os" "os/exec" "os/user" "runtime" "strings" "syscall" "time" - "unicode/utf16" "unsafe" + 
"github.com/dblohm7/wingoes" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" - "tailscale.com/types/logger" ) const ( @@ -558,164 +555,56 @@ func findHomeDirInRegistry(uid string) (dir string, err error) { } const ( - maxBinaryValueLen = 128 // we'll truncate any binary values longer than this - maxRegValueNameLen = 16384 // maximum length supported by Windows + 1 - initialValueBufLen = 80 // large enough to contain a stringified GUID encoded as UTF-16 + _RESTART_NO_CRASH = 1 + _RESTART_NO_HANG = 2 + _RESTART_NO_PATCH = 4 + _RESTART_NO_REBOOT = 8 ) -const ( - supportInfoKeyRegistry = "Registry" -) +func registerForRestart(opts RegisterForRestartOpts) error { + var flags uint32 -// LogSupportInfo obtains information useful for troubleshooting and support, -// and writes it to the log as a JSON-encoded object. -func LogSupportInfo(logf logger.Logf) { - var b strings.Builder - if err := getSupportInfo(&b); err != nil { - log.Printf("error encoding support info: %v", err) - return + if !opts.RestartOnCrash { + flags |= _RESTART_NO_CRASH } - logf("Support Info: %s", b.String()) -} - -func getSupportInfo(w io.Writer) error { - output := make(map[string]any) - - regInfo, err := getRegistrySupportInfo(registry.LOCAL_MACHINE, []string{regPolicyBase, regBase}) - if err == nil { - output[supportInfoKeyRegistry] = regInfo - } else { - output[supportInfoKeyRegistry] = err + if !opts.RestartOnHang { + flags |= _RESTART_NO_HANG } - - enc := json.NewEncoder(w) - return enc.Encode(output) -} - -type getRegistrySupportInfoBufs struct { - nameBuf []uint16 - valueBuf []byte -} - -func getRegistrySupportInfo(root registry.Key, subKeys []string) (map[string]any, error) { - bufs := getRegistrySupportInfoBufs{ - nameBuf: make([]uint16, maxRegValueNameLen), - valueBuf: make([]byte, initialValueBufLen), - } - - output := make(map[string]any) - - for _, subKey := range subKeys { - if err := getRegSubKey(root, subKey, 5, &bufs, output); err != nil && !errors.Is(err, 
registry.ErrNotExist) { - return nil, fmt.Errorf("getRegistrySupportInfo: %w", err) - } - } - - return output, nil -} - -func keyString(key registry.Key, subKey string) string { - var keyStr string - switch key { - case registry.CLASSES_ROOT: - keyStr = `HKCR\` - case registry.CURRENT_USER: - keyStr = `HKCU\` - case registry.LOCAL_MACHINE: - keyStr = `HKLM\` - case registry.USERS: - keyStr = `HKU\` - case registry.CURRENT_CONFIG: - keyStr = `HKCC\` - case registry.PERFORMANCE_DATA: - keyStr = `HKPD\` - default: + if !opts.RestartOnUpgrade { + flags |= _RESTART_NO_PATCH } - - return keyStr + subKey -} - -func getRegSubKey(key registry.Key, subKey string, recursionLimit int, bufs *getRegistrySupportInfoBufs, output map[string]any) error { - keyStr := keyString(key, subKey) - k, err := registry.OpenKey(key, subKey, registry.READ) - if err != nil { - return fmt.Errorf("opening %q: %w", keyStr, err) + if !opts.RestartOnReboot { + flags |= _RESTART_NO_REBOOT } - defer k.Close() - kv := make(map[string]any) - index := uint32(0) - -loopValues: - for { - nbuf := bufs.nameBuf - nameLen := uint32(len(nbuf)) - valueType := uint32(0) - vbuf := bufs.valueBuf - valueLen := uint32(len(vbuf)) - - err := regEnumValue(k, index, &nbuf[0], &nameLen, nil, &valueType, &vbuf[0], &valueLen) - switch err { - case windows.ERROR_NO_MORE_ITEMS: - break loopValues - case windows.ERROR_MORE_DATA: - bufs.valueBuf = make([]byte, valueLen) - continue - case nil: - default: - return fmt.Errorf("regEnumValue: %w", err) + var cmdLine *uint16 + if opts.UseCmdLineArgs { + if len(opts.CmdLineArgs) == 0 { + // re-use our current args, excluding the exe name itself + opts.CmdLineArgs = os.Args[1:] } - var value any - - switch valueType { - case registry.SZ, registry.EXPAND_SZ: - value = windows.UTF16PtrToString((*uint16)(unsafe.Pointer(&vbuf[0]))) - case registry.BINARY: - if valueLen > maxBinaryValueLen { - valueLen = maxBinaryValueLen + var b strings.Builder + for _, arg := range opts.CmdLineArgs { + if 
b.Len() > 0 { + b.WriteByte(' ') } - value = append([]byte{}, vbuf[:valueLen]...) - case registry.DWORD: - value = binary.LittleEndian.Uint32(vbuf[:4]) - case registry.MULTI_SZ: - // Adapted from x/sys/windows/registry/(Key).GetStringsValue - p := (*[1 << 29]uint16)(unsafe.Pointer(&vbuf[0]))[: valueLen/2 : valueLen/2] - var strs []string - if len(p) > 0 { - if p[len(p)-1] == 0 { - p = p[:len(p)-1] - } - strs = make([]string, 0, 5) - from := 0 - for i, c := range p { - if c == 0 { - strs = append(strs, string(utf16.Decode(p[from:i]))) - from = i + 1 - } - } - } - value = strs - case registry.QWORD: - value = binary.LittleEndian.Uint64(vbuf[:8]) - default: - value = fmt.Sprintf("", valueType) + b.WriteString(windows.EscapeArg(arg)) } - kv[windows.UTF16PtrToString(&nbuf[0])] = value - index++ - } - - if recursionLimit > 0 { - if sks, err := k.ReadSubKeyNames(0); err == nil { - for _, sk := range sks { - if err := getRegSubKey(k, sk, recursionLimit-1, bufs, kv); err != nil { - return err - } + if b.Len() > 0 { + var err error + cmdLine, err = windows.UTF16PtrFromString(b.String()) + if err != nil { + return err } } } - output[keyStr] = kv + hr := registerApplicationRestart(cmdLine, flags) + if e := wingoes.ErrorFromHRESULT(hr); e.Failed() { + return e + } + return nil } diff --git a/vendor/tailscale.com/util/winutil/zsyscall_windows.go b/vendor/tailscale.com/util/winutil/zsyscall_windows.go index 930a87522a..77e9f36c82 100644 --- a/vendor/tailscale.com/util/winutil/zsyscall_windows.go +++ b/vendor/tailscale.com/util/winutil/zsyscall_windows.go @@ -6,8 +6,8 @@ import ( "syscall" "unsafe" + "github.com/dblohm7/wingoes" "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" ) var _ unsafe.Pointer @@ -40,9 +40,10 @@ func errnoErr(e syscall.Errno) error { var ( modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") - procRegEnumValueW = 
modadvapi32.NewProc("RegEnumValueW") + procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") + procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") ) func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) { @@ -53,10 +54,8 @@ func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, b return } -func regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) { - r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(valueName)), uintptr(unsafe.Pointer(valueNameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valueType)), uintptr(unsafe.Pointer(pData)), uintptr(unsafe.Pointer(cbData)), 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } +func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) { + r0, _, _ := syscall.Syscall(procRegisterApplicationRestart.Addr(), 2, uintptr(unsafe.Pointer(cmdLineExclExeName)), uintptr(flags), 0) + ret = wingoes.HRESULT(r0) return } diff --git a/vendor/tailscale.com/version/distro/distro.go b/vendor/tailscale.com/version/distro/distro.go index e319d1ba72..8865a834b9 100644 --- a/vendor/tailscale.com/version/distro/distro.go +++ b/vendor/tailscale.com/version/distro/distro.go @@ -30,6 +30,7 @@ const ( Gokrazy = Distro("gokrazy") WDMyCloud = Distro("wdmycloud") Unraid = Distro("unraid") + Alpine = Distro("alpine") ) var distro lazy.SyncValue[Distro] @@ -93,6 +94,8 @@ func linuxDistro() Distro { return WDMyCloud case have("/etc/unraid-version"): return Unraid + case have("/etc/alpine-release"): + return Alpine } return "" } diff --git a/vendor/tailscale.com/version/prop.go b/vendor/tailscale.com/version/prop.go index 386fcd64e0..daccab29e1 100644 --- a/vendor/tailscale.com/version/prop.go +++ 
b/vendor/tailscale.com/version/prop.go @@ -88,7 +88,7 @@ func IsAppleTV() bool { return false } return isAppleTV.Get(func() bool { - return strings.EqualFold(os.Getenv("XPC_SERVICE_NAME"), "io.tailscale.ipn.tvos.network-extension") + return strings.EqualFold(os.Getenv("XPC_SERVICE_NAME"), "io.tailscale.ipn.ios.network-extension-tvos") }) } diff --git a/vendor/tailscale.com/wgengine/filter/filter.go b/vendor/tailscale.com/wgengine/filter/filter.go index 350f319e1c..b5ed82a545 100644 --- a/vendor/tailscale.com/wgengine/filter/filter.go +++ b/vendor/tailscale.com/wgengine/filter/filter.go @@ -7,6 +7,7 @@ package filter import ( "fmt" "net/netip" + "slices" "sync" "time" @@ -15,9 +16,11 @@ import ( "tailscale.com/net/flowtrack" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/tailcfg" "tailscale.com/tstime/rate" "tailscale.com/types/ipproto" "tailscale.com/types/logger" + "tailscale.com/util/mak" ) // Filter is a stateful packet filter. @@ -322,10 +325,9 @@ func (f *Filter) CheckTCP(srcIP, dstIP netip.Addr, dstPort uint16) Response { return f.RunIn(pkt, 0) } -// AppendCaps appends to base the capabilities that srcIP has talking +// CapsWithValues appends to base the capabilities that srcIP has talking // to dstIP. -func (f *Filter) AppendCaps(base []string, srcIP, dstIP netip.Addr) []string { - ret := base +func (f *Filter) CapsWithValues(srcIP, dstIP netip.Addr) tailcfg.PeerCapMap { var mm matches switch { case srcIP.Is4(): @@ -333,17 +335,23 @@ func (f *Filter) AppendCaps(base []string, srcIP, dstIP netip.Addr) []string { case srcIP.Is6(): mm = f.cap6 } + var out tailcfg.PeerCapMap for _, m := range mm { if !ipInList(srcIP, m.Srcs) { continue } for _, cm := range m.Caps { if cm.Cap != "" && cm.Dst.Contains(dstIP) { - ret = append(ret, cm.Cap) + prev, ok := out[cm.Cap] + if !ok { + mak.Set(&out, cm.Cap, slices.Clone(cm.Values)) + continue + } + out[cm.Cap] = append(prev, cm.Values...) 
} } } - return ret + return out } // ShieldsUp reports whether this is a "shields up" (block everything diff --git a/vendor/tailscale.com/wgengine/filter/filter_clone.go b/vendor/tailscale.com/wgengine/filter/filter_clone.go index 794550f570..97366d83ce 100644 --- a/vendor/tailscale.com/wgengine/filter/filter_clone.go +++ b/vendor/tailscale.com/wgengine/filter/filter_clone.go @@ -8,6 +8,7 @@ package filter import ( "net/netip" + "tailscale.com/tailcfg" "tailscale.com/types/ipproto" ) @@ -22,7 +23,12 @@ func (src *Match) Clone() *Match { dst.IPProto = append(src.IPProto[:0:0], src.IPProto...) dst.Srcs = append(src.Srcs[:0:0], src.Srcs...) dst.Dsts = append(src.Dsts[:0:0], src.Dsts...) - dst.Caps = append(src.Caps[:0:0], src.Caps...) + if src.Caps != nil { + dst.Caps = make([]CapMatch, len(src.Caps)) + for i := range dst.Caps { + dst.Caps[i] = *src.Caps[i].Clone() + } + } return dst } @@ -33,3 +39,22 @@ var _MatchCloneNeedsRegeneration = Match(struct { Dsts []NetPortRange Caps []CapMatch }{}) + +// Clone makes a deep copy of CapMatch. +// The result aliases no memory with the original. +func (src *CapMatch) Clone() *CapMatch { + if src == nil { + return nil + } + dst := new(CapMatch) + *dst = *src + dst.Values = append(src.Values[:0:0], src.Values...) + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. 
+var _CapMatchCloneNeedsRegeneration = CapMatch(struct { + Dst netip.Prefix + Cap tailcfg.PeerCapability + Values []tailcfg.RawMessage +}{}) diff --git a/vendor/tailscale.com/wgengine/filter/match.go b/vendor/tailscale.com/wgengine/filter/match.go index 7c73b9630c..2d6f71d197 100644 --- a/vendor/tailscale.com/wgengine/filter/match.go +++ b/vendor/tailscale.com/wgengine/filter/match.go @@ -9,10 +9,11 @@ import ( "strings" "tailscale.com/net/packet" + "tailscale.com/tailcfg" "tailscale.com/types/ipproto" ) -//go:generate go run tailscale.com/cmd/cloner --type=Match +//go:generate go run tailscale.com/cmd/cloner --type=Match,CapMatch // PortRange is a range of TCP and UDP ports. type PortRange struct { @@ -54,7 +55,11 @@ type CapMatch struct { // Cap is the capability that's granted if the destination IP addresses // matches Dst. - Cap string + Cap tailcfg.PeerCapability + + // Values are the raw JSON values of the capability. + // See tailcfg.PeerCapability and tailcfg.PeerCapMap for details. 
+ Values []tailcfg.RawMessage } // Match matches packets from any IP address in Srcs to any ip:port in diff --git a/vendor/tailscale.com/wgengine/filter/tailcfg.go b/vendor/tailscale.com/wgengine/filter/tailcfg.go index de3cde97f3..9f64587d75 100644 --- a/vendor/tailscale.com/wgengine/filter/tailcfg.go +++ b/vendor/tailscale.com/wgengine/filter/tailcfg.go @@ -86,6 +86,13 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { Cap: cap, }) } + for cap, val := range cm.CapMap { + m.Caps = append(m.Caps, CapMatch{ + Dst: dstNet, + Cap: tailcfg.PeerCapability(cap), + Values: val, + }) + } } } diff --git a/vendor/tailscale.com/wgengine/magicsock/batching_conn.go b/vendor/tailscale.com/wgengine/magicsock/batching_conn.go new file mode 100644 index 0000000000..242f31c372 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/batching_conn.go @@ -0,0 +1,203 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "errors" + "net" + "net/netip" + "sync" + "sync/atomic" + "syscall" + "time" + + "golang.org/x/net/ipv6" + "tailscale.com/net/neterror" + "tailscale.com/types/nettype" +) + +// xnetBatchReaderWriter defines the batching i/o methods of +// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn). +// TODO(jwhited): This should eventually be replaced with the standard library +// implementation of https://github.com/golang/go/issues/45886 +type xnetBatchReaderWriter interface { + xnetBatchReader + xnetBatchWriter +} + +type xnetBatchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type xnetBatchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +// batchingUDPConn is a UDP socket that provides batched i/o. 
+type batchingUDPConn struct { + pc nettype.PacketConn + xpc xnetBatchReaderWriter + rxOffload bool // supports UDP GRO or similar + txOffload atomic.Bool // supports UDP GSO or similar + setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing + getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing + sendBatchPool sync.Pool +} + +func (c *batchingUDPConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { + if c.rxOffload { + // UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may + // receive a "monster datagram" from any read call. The ReadFrom() API + // does not support passing the GSO size and is unsafe to use in such a + // case. Other platforms may vary in behavior, but we go with the most + // conservative approach to prevent this from becoming a footgun in the + // future. + return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable") + } + return c.pc.ReadFromUDPAddrPort(p) +} + +func (c *batchingUDPConn) SetDeadline(t time.Time) error { + return c.pc.SetDeadline(t) +} + +func (c *batchingUDPConn) SetReadDeadline(t time.Time) error { + return c.pc.SetReadDeadline(t) +} + +func (c *batchingUDPConn) SetWriteDeadline(t time.Time) error { + return c.pc.SetWriteDeadline(t) +} + +const ( + // This was initially established for Linux, but may split out to + // GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the + // kernel's TX path, and UDP_GRO_CNT_MAX for RX. + udpSegmentMaxDatagrams = 64 +) + +const ( + // Exceeding these values results in EMSGSIZE. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 +) + +// coalesceMessages iterates msgs, coalescing them where possible while +// maintaining datagram order. All msgs have their Addr field set to addr. 
+func (c *batchingUDPConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int { + var ( + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of buffs + ) + maxPayloadLen := maxIPv4PayloadLen + if addr.IP.To4() == nil { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buff := range buffs { + if i > 0 { + msgLen := len(buff) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...) + copy(msgs[base].Buffers[0][baseLenBefore:], buff) + if i == len(buffs)-1 { + c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue + } + } + if dgramCnt > 1 { + c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. 
+ endBatch = false + base++ + gsoSize = len(buff) + msgs[base].OOB = msgs[base].OOB[:0] + msgs[base].Buffers[0] = buff + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type sendBatch struct { + msgs []ipv6.Message + ua *net.UDPAddr +} + +func (c *batchingUDPConn) getSendBatch() *sendBatch { + batch := c.sendBatchPool.Get().(*sendBatch) + return batch +} + +func (c *batchingUDPConn) putSendBatch(batch *sendBatch) { + for i := range batch.msgs { + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} + } + c.sendBatchPool.Put(batch) +} + +func (c *batchingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { + batch := c.getSendBatch() + defer c.putSendBatch(batch) + if addr.Addr().Is6() { + as16 := addr.Addr().As16() + copy(batch.ua.IP, as16[:]) + batch.ua.IP = batch.ua.IP[:16] + } else { + as4 := addr.Addr().As4() + copy(batch.ua.IP, as4[:]) + batch.ua.IP = batch.ua.IP[:4] + } + batch.ua.Port = int(addr.Port()) + var ( + n int + retried bool + ) +retry: + if c.txOffload.Load() { + n = c.coalesceMessages(batch.ua, buffs, batch.msgs) + } else { + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + batch.msgs[i].Addr = batch.ua + batch.msgs[i].OOB = batch.msgs[i].OOB[:0] + } + n = len(buffs) + } + + err := c.writeBatch(batch.msgs[:n]) + if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) { + c.txOffload.Store(false) + retried = true + goto retry + } + if retried { + return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err} + } + return err +} + +func (c *batchingUDPConn) SyscallConn() (syscall.RawConn, error) { + sc, ok := c.pc.(syscall.Conn) + if !ok { + return nil, errUnsupportedConnType + } + return sc.SyscallConn() +} diff --git a/vendor/tailscale.com/wgengine/magicsock/blockforever_conn.go b/vendor/tailscale.com/wgengine/magicsock/blockforever_conn.go new file mode 100644 index 0000000000..f2e85dcd57 --- /dev/null +++ 
b/vendor/tailscale.com/wgengine/magicsock/blockforever_conn.go @@ -0,0 +1,55 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "errors" + "net" + "net/netip" + "sync" + "syscall" + "time" +) + +// blockForeverConn is a net.PacketConn whose reads block until it is closed. +type blockForeverConn struct { + mu sync.Mutex + cond *sync.Cond + closed bool +} + +func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { + c.mu.Lock() + for !c.closed { + c.cond.Wait() + } + c.mu.Unlock() + return 0, netip.AddrPort{}, net.ErrClosed +} + +func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { + // Silently drop writes. + return len(p), nil +} + +func (c *blockForeverConn) LocalAddr() net.Addr { + // Return a *net.UDPAddr because lots of code assumes that it will. + return new(net.UDPAddr) +} + +func (c *blockForeverConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return net.ErrClosed + } + c.closed = true + c.cond.Broadcast() + return nil +} + +func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType } diff --git a/vendor/tailscale.com/wgengine/magicsock/debughttp.go b/vendor/tailscale.com/wgengine/magicsock/debughttp.go index f3b61bca41..f26b500446 100644 --- a/vendor/tailscale.com/wgengine/magicsock/debughttp.go +++ b/vendor/tailscale.com/wgengine/magicsock/debughttp.go @@ -101,11 +101,10 @@ func (c *Conn) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { } sort.Slice(ent, func(i, j int) bool { return ent[i].pub.Less(ent[j].pub) }) - peers := 
map[key.NodePublic]*tailcfg.Node{} - if c.netMap != nil { - for _, p := range c.netMap.Peers { - peers[p.Key] = p - } + peers := map[key.NodePublic]tailcfg.NodeView{} + for i := range c.peers.LenIter() { + p := c.peers.At(i) + peers[p.Key()] = p } for _, e := range ent { @@ -187,15 +186,15 @@ func printEndpointHTML(w io.Writer, ep *endpoint) { } -func peerDebugName(p *tailcfg.Node) string { - if p == nil { +func peerDebugName(p tailcfg.NodeView) string { + if !p.Valid() { return "" } - n := p.Name + n := p.Name() if base, _, ok := strings.Cut(n, "."); ok { return base } - return p.Hostinfo.Hostname() + return p.Hostinfo().Hostname() } func ipPortLess(a, b netip.AddrPort) bool { diff --git a/vendor/tailscale.com/wgengine/magicsock/debugknobs.go b/vendor/tailscale.com/wgengine/magicsock/debugknobs.go index adfff17d2e..824c677f6e 100644 --- a/vendor/tailscale.com/wgengine/magicsock/debugknobs.go +++ b/vendor/tailscale.com/wgengine/magicsock/debugknobs.go @@ -9,14 +9,14 @@ import ( "tailscale.com/envknob" ) -const linkDebug = true - // Various debugging and experimental tweakables, set by environment // variable. var ( // debugDisco prints verbose logs of active discovery events as // they happen. debugDisco = envknob.RegisterBool("TS_DEBUG_DISCO") + // debugPeerMap prints verbose logs of changes to the peermap. + debugPeerMap = envknob.RegisterBool("TS_DEBUG_MAGICSOCK_PEERMAP") // debugOmitLocalAddresses removes all local interface addresses // from magicsock's discovered local endpoints. Used in some tests. debugOmitLocalAddresses = envknob.RegisterBool("TS_DEBUG_OMIT_LOCAL_ADDRS") @@ -44,8 +44,20 @@ var ( // debugSendCallMeUnknownPeer sends a CallMeMaybe to a non-existent destination every // time we send a real CallMeMaybe to test the PeerGoneNotHere logic. debugSendCallMeUnknownPeer = envknob.RegisterBool("TS_DEBUG_SEND_CALLME_UNKNOWN_PEER") - // Hey you! Adding a new debugknob? Make sure to stub it out in the debugknob_stubs.go - // file too. 
+ // debugBindSocket prints extra debugging about socket rebinding in magicsock. + debugBindSocket = envknob.RegisterBool("TS_DEBUG_MAGICSOCK_BIND_SOCKET") + // debugRingBufferMaxSizeBytes overrides the default size of the endpoint + // history ringbuffer. + debugRingBufferMaxSizeBytes = envknob.RegisterInt("TS_DEBUG_MAGICSOCK_RING_BUFFER_MAX_SIZE_BYTES") + // debugEnablePMTUD enables the peer MTU feature, which does path MTU + // discovery on UDP connections between peers. Currently (2023-09-05) + // this only turns on the don't fragment bit for the magicsock UDP + // sockets. + debugEnablePMTUD = envknob.RegisterOptBool("TS_DEBUG_ENABLE_PMTUD") + // debugPMTUD prints extra debugging about peer MTU path discovery. + debugPMTUD = envknob.RegisterBool("TS_DEBUG_PMTUD") + // Hey you! Adding a new debugknob? Make sure to stub it out in the + // debugknobs_stubs.go file too. ) // inTest reports whether the running program is a test that set the diff --git a/vendor/tailscale.com/wgengine/magicsock/debugknobs_stubs.go b/vendor/tailscale.com/wgengine/magicsock/debugknobs_stubs.go index 60fa01100b..de49865bf2 100644 --- a/vendor/tailscale.com/wgengine/magicsock/debugknobs_stubs.go +++ b/vendor/tailscale.com/wgengine/magicsock/debugknobs_stubs.go @@ -11,6 +11,7 @@ import "tailscale.com/types/opt" // // They're inlinable and the linker can deadcode that's guarded by them to make // smaller binaries. 
+func debugBindSocket() bool { return false } func debugDisco() bool { return false } func debugOmitLocalAddresses() bool { return false } func logDerpVerbose() bool { return false } @@ -19,7 +20,11 @@ func debugAlwaysDERP() bool { return false } func debugUseDERPHTTP() bool { return false } func debugEnableSilentDisco() bool { return false } func debugSendCallMeUnknownPeer() bool { return false } +func debugPMTUD() bool { return false } func debugUseDERPAddr() string { return "" } func debugUseDerpRouteEnv() string { return "" } func debugUseDerpRoute() opt.Bool { return "" } +func debugEnablePMTUD() opt.Bool { return "" } +func debugRingBufferMaxSizeBytes() int { return 0 } func inTest() bool { return false } +func debugPeerMap() bool { return false } diff --git a/vendor/tailscale.com/wgengine/magicsock/derp.go b/vendor/tailscale.com/wgengine/magicsock/derp.go new file mode 100644 index 0000000000..fa23d1d555 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/derp.go @@ -0,0 +1,938 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "bufio" + "context" + "fmt" + "hash/fnv" + "math/rand" + "net" + "net/netip" + "reflect" + "runtime" + "sort" + "sync" + "time" + + "github.com/tailscale/wireguard-go/conn" + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/health" + "tailscale.com/logtail/backoff" + "tailscale.com/net/dnscache" + "tailscale.com/net/tsaddr" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/util/mak" + "tailscale.com/util/sysresources" +) + +// useDerpRoute reports whether magicsock should enable the DERP +// return path optimization (Issue 150). +// +// By default it's enabled, unless an environment variable +// or control says to disable it. 
+func (c *Conn) useDerpRoute() bool { + if b, ok := debugUseDerpRoute().Get(); ok { + return b + } + return c.controlKnobs == nil || !c.controlKnobs.DisableDRPO.Load() +} + +// derpRoute is a route entry for a public key, saying that a certain +// peer should be available at DERP node derpID, as long as the +// current connection for that derpID is dc. (but dc should not be +// used to write directly; it's owned by the read/write loops) +type derpRoute struct { + derpID int + dc *derphttp.Client // don't use directly; see comment above +} + +// removeDerpPeerRoute removes a DERP route entry previously added by addDerpPeerRoute. +func (c *Conn) removeDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) { + c.mu.Lock() + defer c.mu.Unlock() + r2 := derpRoute{derpID, dc} + if r, ok := c.derpRoute[peer]; ok && r == r2 { + delete(c.derpRoute, peer) + } +} + +// addDerpPeerRoute adds a DERP route entry, noting that peer was seen +// on DERP node derpID, at least on the connection identified by dc. +// See issue 150 for details. +func (c *Conn) addDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) { + c.mu.Lock() + defer c.mu.Unlock() + mak.Set(&c.derpRoute, peer, derpRoute{derpID, dc}) +} + +// activeDerp contains fields for an active DERP connection. +type activeDerp struct { + c *derphttp.Client + cancel context.CancelFunc + writeCh chan<- derpWriteRequest + // lastWrite is the time of the last request for its write + // channel (currently even if there was no write). + // It is always non-nil and initialized to a non-zero Time. + lastWrite *time.Time + createTime time.Time +} + +var processStartUnixNano = time.Now().UnixNano() + +// pickDERPFallback returns a non-zero but deterministic DERP node to +// connect to. This is only used if netcheck couldn't find the +// nearest one (for instance, if UDP is blocked and thus STUN latency +// checks aren't working). +// +// c.mu must NOT be held. 
+func (c *Conn) pickDERPFallback() int { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.wantDerpLocked() { + return 0 + } + ids := c.derpMap.RegionIDs() + if len(ids) == 0 { + // No DERP regions in non-nil map. + return 0 + } + + // TODO: figure out which DERP region most of our peers are using, + // and use that region as our fallback. + // + // If we already had selected something in the past and it has any + // peers, we want to stay on it. If there are no peers at all, + // stay on whatever DERP we previously picked. If we need to pick + // one and have no peer info, pick a region randomly. + // + // We used to do the above for legacy clients, but never updated + // it for disco. + + if c.myDerp != 0 { + return c.myDerp + } + + h := fnv.New64() + fmt.Fprintf(h, "%p/%d", c, processStartUnixNano) // arbitrary + return ids[rand.New(rand.NewSource(int64(h.Sum64()))).Intn(len(ids))] +} + +func (c *Conn) derpRegionCodeLocked(regionID int) string { + if c.derpMap == nil { + return "" + } + if dr, ok := c.derpMap.Regions[regionID]; ok { + return dr.RegionCode + } + return "" +} + +// c.mu must NOT be held. +func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { + c.mu.Lock() + defer c.mu.Unlock() + if !c.wantDerpLocked() { + c.myDerp = 0 + health.SetMagicSockDERPHome(0) + return false + } + if derpNum == c.myDerp { + // No change. + return true + } + if c.myDerp != 0 && derpNum != 0 { + metricDERPHomeChange.Add(1) + } + c.myDerp = derpNum + health.SetMagicSockDERPHome(derpNum) + + if c.privateKey.IsZero() { + // No private key yet, so DERP connections won't come up anyway. + // Return early rather than ultimately log a couple lines of noise. + return true + } + + // On change, notify all currently connected DERP servers and + // start connecting to our home DERP if we are not already. 
+ dr := c.derpMap.Regions[derpNum] + if dr == nil { + c.logf("[unexpected] magicsock: derpMap.Regions[%v] is nil", derpNum) + } else { + c.logf("magicsock: home is now derp-%v (%v)", derpNum, c.derpMap.Regions[derpNum].RegionCode) + } + for i, ad := range c.activeDerp { + go ad.c.NotePreferred(i == c.myDerp) + } + c.goDerpConnect(derpNum) + return true +} + +// startDerpHomeConnectLocked starts connecting to our DERP home, if any. +// +// c.mu must be held. +func (c *Conn) startDerpHomeConnectLocked() { + c.goDerpConnect(c.myDerp) +} + +// goDerpConnect starts a goroutine to start connecting to the given +// DERP node. +// +// c.mu may be held, but does not need to be. +func (c *Conn) goDerpConnect(node int) { + if node == 0 { + return + } + go c.derpWriteChanOfAddr(netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(node)), key.NodePublic{}) +} + +var ( + bufferedDerpWrites int + bufferedDerpWritesOnce sync.Once +) + +// bufferedDerpWritesBeforeDrop returns how many packets writes can be queued +// up the DERP client to write on the wire before we start dropping. +func bufferedDerpWritesBeforeDrop() int { + // For mobile devices, always return the previous minimum value of 32; + // we can do this outside the sync.Once to avoid that overhead. + if runtime.GOOS == "ios" || runtime.GOOS == "android" { + return 32 + } + + bufferedDerpWritesOnce.Do(func() { + // Some rough sizing: for the previous fixed value of 32, the + // total consumed memory can be: + // = numDerpRegions * messages/region * sizeof(message) + // + // For sake of this calculation, assume 100 DERP regions; at + // time of writing (2023-04-03), we have 24. + // + // A reasonable upper bound for the worst-case average size of + // a message is a *disco.CallMeMaybe message with 16 endpoints; + // since sizeof(netip.AddrPort) = 32, that's 512 bytes. 
Thus: + // = 100 * 32 * 512 + // = 1638400 (1.6MiB) + // + // On a reasonably-small node with 4GiB of memory that's + // connected to each region and handling a lot of load, 1.6MiB + // is about 0.04% of the total system memory. + // + // For sake of this calculation, then, let's double that memory + // usage to 0.08% and scale based on total system memory. + // + // For a 16GiB Linux box, this should buffer just over 256 + // messages. + systemMemory := sysresources.TotalMemory() + memoryUsable := float64(systemMemory) * 0.0008 + + const ( + theoreticalDERPRegions = 100 + messageMaximumSizeBytes = 512 + ) + bufferedDerpWrites = int(memoryUsable / (theoreticalDERPRegions * messageMaximumSizeBytes)) + + // Never drop below the previous minimum value. + if bufferedDerpWrites < 32 { + bufferedDerpWrites = 32 + } + }) + return bufferedDerpWrites +} + +// derpWriteChanOfAddr returns a DERP client for fake UDP addresses that +// represent DERP servers, creating them as necessary. For real UDP +// addresses, it returns nil. +// +// If peer is non-zero, it can be used to find an active reverse +// path, without using addr. +func (c *Conn) derpWriteChanOfAddr(addr netip.AddrPort, peer key.NodePublic) chan<- derpWriteRequest { + if addr.Addr() != tailcfg.DerpMagicIPAddr { + return nil + } + regionID := int(addr.Port()) + + if c.networkDown() { + return nil + } + + c.mu.Lock() + defer c.mu.Unlock() + if !c.wantDerpLocked() || c.closed { + return nil + } + if c.derpMap == nil || c.derpMap.Regions[regionID] == nil { + return nil + } + if c.privateKey.IsZero() { + c.logf("magicsock: DERP lookup of %v with no private key; ignoring", addr) + return nil + } + + // See if we have a connection open to that DERP node ID + // first. If so, might as well use it. (It's a little + // arbitrary whether we use this one vs. the reverse route + // below when we have both.) 
+ ad, ok := c.activeDerp[regionID] + if ok { + *ad.lastWrite = time.Now() + c.setPeerLastDerpLocked(peer, regionID, regionID) + return ad.writeCh + } + + // If we don't have an open connection to the peer's home DERP + // node, see if we have an open connection to a DERP node + // where we'd heard from that peer already. For instance, + // perhaps peer's home is Frankfurt, but they dialed our home DERP + // node in SF to reach us, so we can reply to them using our + // SF connection rather than dialing Frankfurt. (Issue 150) + if !peer.IsZero() && c.useDerpRoute() { + if r, ok := c.derpRoute[peer]; ok { + if ad, ok := c.activeDerp[r.derpID]; ok && ad.c == r.dc { + c.setPeerLastDerpLocked(peer, r.derpID, regionID) + *ad.lastWrite = time.Now() + return ad.writeCh + } + } + } + + why := "home-keep-alive" + if !peer.IsZero() { + why = peer.ShortString() + } + c.logf("magicsock: adding connection to derp-%v for %v", regionID, why) + + firstDerp := false + if c.activeDerp == nil { + firstDerp = true + c.activeDerp = make(map[int]activeDerp) + c.prevDerp = make(map[int]*syncs.WaitGroupChan) + } + + // Note that derphttp.NewRegionClient does not dial the server + // (it doesn't block) so it is safe to do under the c.mu lock. + dc := derphttp.NewRegionClient(c.privateKey, c.logf, c.netMon, func() *tailcfg.DERPRegion { + // Warning: it is not legal to acquire + // magicsock.Conn.mu from this callback. + // It's run from derphttp.Client.connect (via Send, etc) + // and the lock ordering rules are that magicsock.Conn.mu + // must be acquired before derphttp.Client.mu. + // See https://github.com/tailscale/tailscale/issues/3726 + if c.connCtx.Err() != nil { + // We're closing anyway; return nil to stop dialing. 
+ return nil + } + derpMap := c.derpMapAtomic.Load() + if derpMap == nil { + return nil + } + return derpMap.Regions[regionID] + }) + + dc.SetCanAckPings(true) + dc.NotePreferred(c.myDerp == regionID) + dc.SetAddressFamilySelector(derpAddrFamSelector{c}) + dc.DNSCache = dnscache.Get() + + ctx, cancel := context.WithCancel(c.connCtx) + ch := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop()) + + ad.c = dc + ad.writeCh = ch + ad.cancel = cancel + ad.lastWrite = new(time.Time) + *ad.lastWrite = time.Now() + ad.createTime = time.Now() + c.activeDerp[regionID] = ad + metricNumDERPConns.Set(int64(len(c.activeDerp))) + c.logActiveDerpLocked() + c.setPeerLastDerpLocked(peer, regionID, regionID) + c.scheduleCleanStaleDerpLocked() + + // Build a startGate for the derp reader+writer + // goroutines, so they don't start running until any + // previous generation is closed. + startGate := syncs.ClosedChan() + if prev := c.prevDerp[regionID]; prev != nil { + startGate = prev.DoneChan() + } + // And register a WaitGroup(Chan) for this generation. + wg := syncs.NewWaitGroupChan() + wg.Add(2) + c.prevDerp[regionID] = wg + + if firstDerp { + startGate = c.derpStarted + go func() { + dc.Connect(ctx) + close(c.derpStarted) + c.muCond.Broadcast() + }() + } + + go c.runDerpReader(ctx, addr, dc, wg, startGate) + go c.runDerpWriter(ctx, dc, ch, wg, startGate) + go c.derpActiveFunc() + + return ad.writeCh +} + +// setPeerLastDerpLocked notes that peer is now being written to via +// the provided DERP regionID, and that the peer advertises a DERP +// home region ID of homeID. +// +// If there's any change, it logs. +// +// c.mu must be held. 
+func (c *Conn) setPeerLastDerpLocked(peer key.NodePublic, regionID, homeID int) { + if peer.IsZero() { + return + } + old := c.peerLastDerp[peer] + if old == regionID { + return + } + c.peerLastDerp[peer] = regionID + + var newDesc string + switch { + case regionID == homeID && regionID == c.myDerp: + newDesc = "shared home" + case regionID == homeID: + newDesc = "their home" + case regionID == c.myDerp: + newDesc = "our home" + case regionID != homeID: + newDesc = "alt" + } + if old == 0 { + c.logf("[v1] magicsock: derp route for %s set to derp-%d (%s)", peer.ShortString(), regionID, newDesc) + } else { + c.logf("[v1] magicsock: derp route for %s changed from derp-%d => derp-%d (%s)", peer.ShortString(), old, regionID, newDesc) + } +} + +// derpReadResult is the type sent by runDerpClient to ReceiveIPv4 +// when a DERP packet is available. +// +// Notably, it doesn't include the derp.ReceivedPacket because we +// don't want to give the receiver access to the aliased []byte. To +// get at the packet contents they need to call copyBuf to copy it +// out, which also releases the buffer. +type derpReadResult struct { + regionID int + n int // length of data received + src key.NodePublic + // copyBuf is called to copy the data to dst. It returns how + // much data was copied, which will be n if dst is large + // enough. copyBuf can only be called once. + // If copyBuf is nil, that's a signal from the sender to ignore + // this message. + copyBuf func(dst []byte) int +} + +// runDerpReader runs in a goroutine for the life of a DERP +// connection, handling received packets. 
+func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr netip.AddrPort, dc *derphttp.Client, wg *syncs.WaitGroupChan, startGate <-chan struct{}) { + defer wg.Decr() + defer dc.Close() + + select { + case <-startGate: + case <-ctx.Done(): + return + } + + didCopy := make(chan struct{}, 1) + regionID := int(derpFakeAddr.Port()) + res := derpReadResult{regionID: regionID} + var pkt derp.ReceivedPacket + res.copyBuf = func(dst []byte) int { + n := copy(dst, pkt.Data) + didCopy <- struct{}{} + return n + } + + defer health.SetDERPRegionConnectedState(regionID, false) + defer health.SetDERPRegionHealth(regionID, "") + + // peerPresent is the set of senders we know are present on this + // connection, based on messages we've received from the server. + peerPresent := map[key.NodePublic]bool{} + bo := backoff.NewBackoff(fmt.Sprintf("derp-%d", regionID), c.logf, 5*time.Second) + var lastPacketTime time.Time + var lastPacketSrc key.NodePublic + + for { + msg, connGen, err := dc.RecvDetail() + if err != nil { + health.SetDERPRegionConnectedState(regionID, false) + // Forget that all these peers have routes. + for peer := range peerPresent { + delete(peerPresent, peer) + c.removeDerpPeerRoute(peer, regionID, dc) + } + if err == derphttp.ErrClientClosed { + return + } + if c.networkDown() { + c.logf("[v1] magicsock: derp.Recv(derp-%d): network down, closing", regionID) + return + } + select { + case <-ctx.Done(): + return + default: + } + + c.logf("magicsock: [%p] derp.Recv(derp-%d): %v", dc, regionID, err) + + // If our DERP connection broke, it might be because our network + // conditions changed. Start that check. + c.ReSTUN("derp-recv-error") + + // Back off a bit before reconnecting. 
+ bo.BackOff(ctx, err) + select { + case <-ctx.Done(): + return + default: + } + continue + } + bo.BackOff(ctx, nil) // reset + + now := time.Now() + if lastPacketTime.IsZero() || now.Sub(lastPacketTime) > 5*time.Second { + health.NoteDERPRegionReceivedFrame(regionID) + lastPacketTime = now + } + + switch m := msg.(type) { + case derp.ServerInfoMessage: + health.SetDERPRegionConnectedState(regionID, true) + health.SetDERPRegionHealth(regionID, "") // until declared otherwise + c.logf("magicsock: derp-%d connected; connGen=%v", regionID, connGen) + continue + case derp.ReceivedPacket: + pkt = m + res.n = len(m.Data) + res.src = m.Source + if logDerpVerbose() { + c.logf("magicsock: got derp-%v packet: %q", regionID, m.Data) + } + // If this is a new sender we hadn't seen before, remember it and + // register a route for this peer. + if res.src != lastPacketSrc { // avoid map lookup w/ high throughput single peer + lastPacketSrc = res.src + if _, ok := peerPresent[res.src]; !ok { + peerPresent[res.src] = true + c.addDerpPeerRoute(res.src, regionID, dc) + } + } + case derp.PingMessage: + // Best effort reply to the ping. + pingData := [8]byte(m) + go func() { + if err := dc.SendPong(pingData); err != nil { + c.logf("magicsock: derp-%d SendPong error: %v", regionID, err) + } + }() + continue + case derp.HealthMessage: + health.SetDERPRegionHealth(regionID, m.Problem) + case derp.PeerGoneMessage: + switch m.Reason { + case derp.PeerGoneReasonDisconnected: + // Do nothing. + case derp.PeerGoneReasonNotHere: + metricRecvDiscoDERPPeerNotHere.Add(1) + c.logf("[unexpected] magicsock: derp-%d does not know about peer %s, removing route", + regionID, key.NodePublic(m.Peer).ShortString()) + default: + metricRecvDiscoDERPPeerGoneUnknown.Add(1) + c.logf("[unexpected] magicsock: derp-%d peer %s gone, reason %v, removing route", + regionID, key.NodePublic(m.Peer).ShortString(), m.Reason) + } + c.removeDerpPeerRoute(key.NodePublic(m.Peer), regionID, dc) + default: + // Ignore. 
+ continue + } + + select { + case <-ctx.Done(): + return + case c.derpRecvCh <- res: + } + + select { + case <-ctx.Done(): + return + case <-didCopy: + continue + } + } +} + +type derpWriteRequest struct { + addr netip.AddrPort + pubKey key.NodePublic + b []byte // copied; ownership passed to receiver +} + +// runDerpWriter runs in a goroutine for the life of a DERP +// connection, handling received packets. +func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan derpWriteRequest, wg *syncs.WaitGroupChan, startGate <-chan struct{}) { + defer wg.Decr() + select { + case <-startGate: + case <-ctx.Done(): + return + } + + for { + select { + case <-ctx.Done(): + return + case wr := <-ch: + err := dc.Send(wr.pubKey, wr.b) + if err != nil { + c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) + metricSendDERPError.Add(1) + } else { + metricSendDERP.Add(1) + } + } + } +} + +func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { + health.ReceiveDERP.Enter() + defer health.ReceiveDERP.Exit() + + for dm := range c.derpRecvCh { + if c.isClosed() { + break + } + n, ep := c.processDERPReadResult(dm, buffs[0]) + if n == 0 { + // No data read occurred. Wait for another packet. 
+ continue + } + metricRecvDataDERP.Add(1) + sizes[0] = n + eps[0] = ep + return 1, nil + } + return 0, net.ErrClosed +} + +func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *endpoint) { + if dm.copyBuf == nil { + return 0, nil + } + var regionID int + n, regionID = dm.n, dm.regionID + ncopy := dm.copyBuf(b) + if ncopy != n { + err := fmt.Errorf("received DERP packet of length %d that's too big for WireGuard buf size %d", n, ncopy) + c.logf("magicsock: %v", err) + return 0, nil + } + + ipp := netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID)) + if c.handleDiscoMessage(b[:n], ipp, dm.src, discoRXPathDERP) { + return 0, nil + } + + var ok bool + c.mu.Lock() + ep, ok = c.peerMap.endpointForNodeKey(dm.src) + c.mu.Unlock() + if !ok { + // We don't know anything about this node key, nothing to + // record or process. + return 0, nil + } + + ep.noteRecvActivity(ipp) + if stats := c.stats.Load(); stats != nil { + stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n) + } + return n, ep +} + +// SetDERPMap controls which (if any) DERP servers are used. +// A nil value means to disable DERP; it's disabled by default. +func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) { + c.mu.Lock() + defer c.mu.Unlock() + + var derpAddr = debugUseDERPAddr() + if derpAddr != "" { + derpPort := 443 + if debugUseDERPHTTP() { + // Match the port for -dev in derper.go + derpPort = 3340 + } + dm = &tailcfg.DERPMap{ + OmitDefaultRegions: true, + Regions: map[int]*tailcfg.DERPRegion{ + 999: { + RegionID: 999, + Nodes: []*tailcfg.DERPNode{{ + Name: "999dev", + RegionID: 999, + HostName: derpAddr, + DERPPort: derpPort, + }}, + }, + }, + } + } + + if reflect.DeepEqual(dm, c.derpMap) { + return + } + + c.derpMapAtomic.Store(dm) + old := c.derpMap + c.derpMap = dm + if dm == nil { + c.closeAllDerpLocked("derp-disabled") + return + } + + // Reconnect any DERP region that changed definitions. 
+ if old != nil { + changes := false + for rid, oldDef := range old.Regions { + if reflect.DeepEqual(oldDef, dm.Regions[rid]) { + continue + } + changes = true + if rid == c.myDerp { + c.myDerp = 0 + } + c.closeDerpLocked(rid, "derp-region-redefined") + } + if changes { + c.logActiveDerpLocked() + } + } + + go c.ReSTUN("derp-map-update") +} +func (c *Conn) wantDerpLocked() bool { return c.derpMap != nil } + +// c.mu must be held. +func (c *Conn) closeAllDerpLocked(why string) { + if len(c.activeDerp) == 0 { + return // without the useless log statement + } + for i := range c.activeDerp { + c.closeDerpLocked(i, why) + } + c.logActiveDerpLocked() +} + +// DebugBreakDERPConns breaks all DERP connections for debug/testing reasons. +func (c *Conn) DebugBreakDERPConns() error { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.activeDerp) == 0 { + c.logf("magicsock: DebugBreakDERPConns: no active DERP connections") + return nil + } + c.closeAllDerpLocked("debug-break-derp") + c.startDerpHomeConnectLocked() + return nil +} + +// maybeCloseDERPsOnRebind, in response to a rebind, closes all +// DERP connections that don't have a local address in okayLocalIPs +// and pings all those that do. 
+func (c *Conn) maybeCloseDERPsOnRebind(okayLocalIPs []netip.Prefix) { + c.mu.Lock() + defer c.mu.Unlock() + for regionID, ad := range c.activeDerp { + la, err := ad.c.LocalAddr() + if err != nil { + c.closeOrReconnectDERPLocked(regionID, "rebind-no-localaddr") + continue + } + if !tsaddr.PrefixesContainsIP(okayLocalIPs, la.Addr()) { + c.closeOrReconnectDERPLocked(regionID, "rebind-default-route-change") + continue + } + regionID := regionID + dc := ad.c + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if err := dc.Ping(ctx); err != nil { + c.mu.Lock() + defer c.mu.Unlock() + c.closeOrReconnectDERPLocked(regionID, "rebind-ping-fail") + return + } + c.logf("post-rebind ping of DERP region %d okay", regionID) + }() + } + c.logActiveDerpLocked() +} + +// closeOrReconnectDERPLocked closes the DERP connection to the +// provided regionID and starts reconnecting it if it's our current +// home DERP. +// +// why is a reason for logging. +// +// c.mu must be held. +func (c *Conn) closeOrReconnectDERPLocked(regionID int, why string) { + c.closeDerpLocked(regionID, why) + if !c.privateKey.IsZero() && c.myDerp == regionID { + c.startDerpHomeConnectLocked() + } +} + +// c.mu must be held. +// It is the responsibility of the caller to call logActiveDerpLocked after any set of closes. +func (c *Conn) closeDerpLocked(regionID int, why string) { + if ad, ok := c.activeDerp[regionID]; ok { + c.logf("magicsock: closing connection to derp-%v (%v), age %v", regionID, why, time.Since(ad.createTime).Round(time.Second)) + go ad.c.Close() + ad.cancel() + delete(c.activeDerp, regionID) + metricNumDERPConns.Set(int64(len(c.activeDerp))) + } +} + +// c.mu must be held. 
+func (c *Conn) logActiveDerpLocked() { + now := time.Now() + c.logf("magicsock: %v active derp conns%s", len(c.activeDerp), logger.ArgWriter(func(buf *bufio.Writer) { + if len(c.activeDerp) == 0 { + return + } + buf.WriteString(":") + c.foreachActiveDerpSortedLocked(func(node int, ad activeDerp) { + fmt.Fprintf(buf, " derp-%d=cr%v,wr%v", node, simpleDur(now.Sub(ad.createTime)), simpleDur(now.Sub(*ad.lastWrite))) + }) + })) +} + +// c.mu must be held. +func (c *Conn) foreachActiveDerpSortedLocked(fn func(regionID int, ad activeDerp)) { + if len(c.activeDerp) < 2 { + for id, ad := range c.activeDerp { + fn(id, ad) + } + return + } + ids := make([]int, 0, len(c.activeDerp)) + for id := range c.activeDerp { + ids = append(ids, id) + } + sort.Ints(ids) + for _, id := range ids { + fn(id, c.activeDerp[id]) + } +} + +func (c *Conn) cleanStaleDerp() { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return + } + c.derpCleanupTimerArmed = false + + tooOld := time.Now().Add(-derpInactiveCleanupTime) + dirty := false + someNonHomeOpen := false + for i, ad := range c.activeDerp { + if i == c.myDerp { + continue + } + if ad.lastWrite.Before(tooOld) { + c.closeDerpLocked(i, "idle") + dirty = true + } else { + someNonHomeOpen = true + } + } + if dirty { + c.logActiveDerpLocked() + } + if someNonHomeOpen { + c.scheduleCleanStaleDerpLocked() + } +} + +func (c *Conn) scheduleCleanStaleDerpLocked() { + if c.derpCleanupTimerArmed { + // Already going to fire soon. Let the existing one + // fire lest it get infinitely delayed by repeated + // calls to scheduleCleanStaleDerpLocked. + return + } + c.derpCleanupTimerArmed = true + if c.derpCleanupTimer != nil { + c.derpCleanupTimer.Reset(derpCleanStaleInterval) + } else { + c.derpCleanupTimer = time.AfterFunc(derpCleanStaleInterval, c.cleanStaleDerp) + } +} + +// DERPs reports the number of active DERP connections. 
+func (c *Conn) DERPs() int { + c.mu.Lock() + defer c.mu.Unlock() + + return len(c.activeDerp) +} + +func (c *Conn) derpRegionCodeOfIDLocked(regionID int) string { + if c.derpMap == nil { + return "" + } + if r, ok := c.derpMap.Regions[regionID]; ok { + return r.RegionCode + } + return "" +} + +// derpAddrFamSelector is the derphttp.AddressFamilySelector we pass +// to derphttp.Client.SetAddressFamilySelector. +// +// It provides the hint as to whether in an IPv4-vs-IPv6 race that +// IPv4 should be held back a bit to give IPv6 a better-than-50/50 +// chance of winning. We only return true when we believe IPv6 will +// work anyway, so we don't artificially delay the connection speed. +type derpAddrFamSelector struct{ c *Conn } + +func (s derpAddrFamSelector) PreferIPv6() bool { + if r := s.c.lastNetCheckReport.Load(); r != nil { + return r.IPv6 + } + return false +} + +const ( + // derpInactiveCleanupTime is how long a non-home DERP connection + // needs to be idle (last written to) before we close it. + derpInactiveCleanupTime = 60 * time.Second + + // derpCleanStaleInterval is how often cleanStaleDerp runs when there + // are potentially-stale DERP connections to close. 
+ derpCleanStaleInterval = 15 * time.Second +) diff --git a/vendor/tailscale.com/wgengine/magicsock/endpoint.go b/vendor/tailscale.com/wgengine/magicsock/endpoint.go new file mode 100644 index 0000000000..1825abca2e --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/endpoint.go @@ -0,0 +1,1241 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "bufio" + "context" + "encoding/binary" + "errors" + "fmt" + "math" + "math/rand" + "net" + "net/netip" + "reflect" + "runtime" + "sync" + "sync/atomic" + "time" + + "golang.org/x/crypto/poly1305" + xmaps "golang.org/x/exp/maps" + "tailscale.com/disco" + "tailscale.com/ipn/ipnstate" + "tailscale.com/net/stun" + "tailscale.com/net/tstun" + "tailscale.com/tailcfg" + "tailscale.com/tstime/mono" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/views" + "tailscale.com/util/mak" + "tailscale.com/util/ringbuffer" +) + +// endpoint is a wireguard/conn.Endpoint. In wireguard-go and kernel WireGuard +// there is only one endpoint for a peer, but in Tailscale we distribute a +// number of possible endpoints for a peer which would include the all the +// likely addresses at which a peer may be reachable. This endpoint type holds +// the information required that when WiregGuard-Go wants to send to a +// particular peer (essentally represented by this endpoint type), the send +// function can use the currnetly best known Tailscale endpoint to send packets +// to the peer. +type endpoint struct { + // atomically accessed; declared first for alignment reasons + lastRecv mono.Time + numStopAndResetAtomic int64 + debugUpdates *ringbuffer.RingBuffer[EndpointChange] + + // These fields are initialized once and never modified. 
+ c *Conn + nodeID tailcfg.NodeID + publicKey key.NodePublic // peer public key (for WireGuard + DERP) + publicKeyHex string // cached output of publicKey.UntypedHexString + fakeWGAddr netip.AddrPort // the UDP address we tell wireguard-go we're using + nodeAddr netip.Addr // the node's first tailscale address; used for logging & wireguard rate-limiting (Issue 6686) + + disco atomic.Pointer[endpointDisco] // if the peer supports disco, the key and short string + + // mu protects all following fields. + mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu + + heartBeatTimer *time.Timer // nil when idle + lastSend mono.Time // last time there was outgoing packets sent to this peer (from wireguard-go) + lastFullPing mono.Time // last time we pinged all disco or wireguard only endpoints + derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) + + bestAddr addrLatency // best non-DERP path; zero if none + bestAddrAt mono.Time // time best address re-confirmed + trustBestAddrUntil mono.Time // time when bestAddr expires + sentPing map[stun.TxID]sentPing + endpointState map[netip.AddrPort]*endpointState + isCallMeMaybeEP map[netip.AddrPort]bool + + // The following fields are related to the new "silent disco" + // implementation that's a WIP as of 2022-10-20. + // See #540 for background. + heartbeatDisabled bool + + expired bool // whether the node has expired + isWireguardOnly bool // whether the endpoint is WireGuard only +} + +// endpointDisco is the current disco key and short string for an endpoint. This +// structure is immutable. +type endpointDisco struct { + key key.DiscoPublic // for discovery messages. + short string // ShortString of discoKey. 
+} + +type sentPing struct { + to netip.AddrPort + at mono.Time + timer *time.Timer // timeout timer + purpose discoPingPurpose + res *ipnstate.PingResult // nil unless CLI ping + cb func(*ipnstate.PingResult) // nil unless CLI ping +} + +// endpointState is some state and history for a specific endpoint of +// a endpoint. (The subject is the endpoint.endpointState +// map key) +type endpointState struct { + // all fields guarded by endpoint.mu + + // lastPing is the last (outgoing) ping time. + lastPing mono.Time + + // lastGotPing, if non-zero, means that this was an endpoint + // that we learned about at runtime (from an incoming ping) + // and that is not in the network map. If so, we keep the time + // updated and use it to discard old candidates. + lastGotPing time.Time + + // lastGotPingTxID contains the TxID for the last incoming ping. This is + // used to de-dup incoming pings that we may see on both the raw disco + // socket on Linux, and UDP socket. We cannot rely solely on the raw socket + // disco handling due to https://github.com/tailscale/tailscale/issues/7078. + lastGotPingTxID stun.TxID + + // callMeMaybeTime, if non-zero, is the time this endpoint + // was advertised last via a call-me-maybe disco message. + callMeMaybeTime time.Time + + recentPongs []pongReply // ring buffer up to pongHistoryCount entries + recentPong uint16 // index into recentPongs of most recent; older before, wrapped + + index int16 // index in nodecfg.Node.Endpoints; meaningless if lastGotPing non-zero +} + +// clear removes all derived / probed state from an endpointState. 
+func (s *endpointState) clear() { + *s = endpointState{ + index: s.index, + lastGotPing: s.lastGotPing, + } +} + +// pongHistoryCount is how many pongReply values we keep per endpointState +const pongHistoryCount = 64 + +type pongReply struct { + latency time.Duration + pongAt mono.Time // when we received the pong + from netip.AddrPort // the pong's src (usually same as endpoint map key) + pongSrc netip.AddrPort // what they reported they heard +} + +// EndpointChange is a structure containing information about changes made to a +// particular endpoint. This is not a stable interface and could change at any +// time. +type EndpointChange struct { + When time.Time // when the change occurred + What string // what this change is + From any `json:",omitempty"` // information about the previous state + To any `json:",omitempty"` // information about the new state +} + +// shouldDeleteLocked reports whether we should delete this endpoint. +func (st *endpointState) shouldDeleteLocked() bool { + switch { + case !st.callMeMaybeTime.IsZero(): + return false + case st.lastGotPing.IsZero(): + // This was an endpoint from the network map. Is it still in the network map? + return st.index == indexSentinelDeleted + default: + // This was an endpoint discovered at runtime. + return time.Since(st.lastGotPing) > sessionActiveTimeout + } +} + +// latencyLocked returns the most recent latency measurement, if any. +// endpoint.mu must be held. +func (st *endpointState) latencyLocked() (lat time.Duration, ok bool) { + if len(st.recentPongs) == 0 { + return 0, false + } + return st.recentPongs[st.recentPong].latency, true +} + +// endpoint.mu must be held. 
+func (st *endpointState) addPongReplyLocked(r pongReply) { + if n := len(st.recentPongs); n < pongHistoryCount { + st.recentPong = uint16(n) + st.recentPongs = append(st.recentPongs, r) + return + } + i := st.recentPong + 1 + if i == pongHistoryCount { + i = 0 + } + st.recentPongs[i] = r + st.recentPong = i +} + +func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) { + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "deleteEndpointLocked-" + why, + From: ep, + }) + delete(de.endpointState, ep) + if de.bestAddr.AddrPort == ep { + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "deleteEndpointLocked-bestAddr-" + why, + From: de.bestAddr, + }) + de.bestAddr = addrLatency{} + } +} + +// initFakeUDPAddr populates fakeWGAddr with a globally unique fake UDPAddr. +// The current implementation just uses the pointer value of de jammed into an IPv6 +// address, but it could also be, say, a counter. +func (de *endpoint) initFakeUDPAddr() { + var addr [16]byte + addr[0] = 0xfd + addr[1] = 0x00 + binary.BigEndian.PutUint64(addr[2:], uint64(reflect.ValueOf(de).Pointer())) + de.fakeWGAddr = netip.AddrPortFrom(netip.AddrFrom16(addr).Unmap(), 12345) +} + +// noteRecvActivity records receive activity on de, and invokes +// Conn.noteRecvActivity no more than once every 10s. +func (de *endpoint) noteRecvActivity(ipp netip.AddrPort) { + now := mono.Now() + + // TODO(raggi): this probably applies relatively equally well to disco + // managed endpoints, but that would be a less conservative change. 
+ if de.isWireguardOnly { + de.mu.Lock() + de.bestAddr.AddrPort = ipp + de.bestAddrAt = now + de.trustBestAddrUntil = now.Add(5 * time.Second) + de.mu.Unlock() + } + + elapsed := now.Sub(de.lastRecv.LoadAtomic()) + if elapsed > 10*time.Second { + de.lastRecv.StoreAtomic(now) + + if de.c.noteRecvActivity == nil { + return + } + de.c.noteRecvActivity(de.publicKey) + } +} + +func (de *endpoint) discoShort() string { + var short string + if d := de.disco.Load(); d != nil { + short = d.short + } + return short +} + +// String exists purely so wireguard-go internals can log.Printf("%v") +// its internal conn.Endpoints and we don't end up with data races +// from fmt (via log) reading mutex fields and such. +func (de *endpoint) String() string { + return fmt.Sprintf("magicsock.endpoint{%v, %v}", de.publicKey.ShortString(), de.discoShort()) +} + +func (de *endpoint) ClearSrc() {} +func (de *endpoint) SrcToString() string { panic("unused") } // unused by wireguard-go +func (de *endpoint) SrcIP() netip.Addr { panic("unused") } // unused by wireguard-go +func (de *endpoint) DstToString() string { return de.publicKeyHex } +func (de *endpoint) DstIP() netip.Addr { return de.nodeAddr } // see tailscale/tailscale#6686 +func (de *endpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) } + +// addrForSendLocked returns the address(es) that should be used for +// sending the next packet. Zero, one, or both of UDP address and DERP +// addr may be non-zero. If the endpoint is WireGuard only and does not have +// latency information, a bool is returned to indiciate that the +// WireGuard latency discovery pings should be sent. +// +// de.mu must be held. 
+func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.AddrPort, sendWGPing bool) { + udpAddr = de.bestAddr.AddrPort + + if udpAddr.IsValid() && !now.After(de.trustBestAddrUntil) { + return udpAddr, netip.AddrPort{}, false + } + + if de.isWireguardOnly { + // If the endpoint is wireguard-only, we don't have a DERP + // address to send to, so we have to send to the UDP address. + udpAddr, shouldPing := de.addrForWireGuardSendLocked(now) + return udpAddr, netip.AddrPort{}, shouldPing + } + + // We had a bestAddr but it expired so send both to it + // and DERP. + return udpAddr, de.derpAddr, false +} + +// addrForWireGuardSendLocked returns the address that should be used for +// sending the next packet. If a packet has never or not recently been sent to +// the endpoint, then a randomly selected address for the endpoint is returned, +// as well as a bool indiciating that WireGuard discovery pings should be started. +// If the addresses have latency information available, then the address with the +// best latency is used. +// +// de.mu must be held. +func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) { + if len(de.endpointState) == 0 { + de.c.logf("magicsock: addrForSendWireguardLocked: [unexpected] no candidates available for endpoint") + return udpAddr, false + } + + // lowestLatency is a high duration initially, so we + // can be sure we're going to have a duration lower than this + // for the first latency retrieved. + lowestLatency := time.Hour + var oldestPing mono.Time + for ipp, state := range de.endpointState { + if oldestPing.IsZero() { + oldestPing = state.lastPing + } else if state.lastPing.Before(oldestPing) { + oldestPing = state.lastPing + } + + if latency, ok := state.latencyLocked(); ok { + if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() { + // If we have the same latency,IPv6 is prioritized. 
+ // TODO(catzkorn): Consider a small increase in latency to use + // IPv6 in comparison to IPv4, when possible. + lowestLatency = latency + udpAddr = ipp + } + } + } + needPing := len(de.endpointState) > 1 && now.Sub(oldestPing) > wireguardPingInterval + + if !udpAddr.IsValid() { + candidates := xmaps.Keys(de.endpointState) + + // Randomly select an address to use until we retrieve latency information + // and give it a short trustBestAddrUntil time so we avoid flapping between + // addresses while waiting on latency information to be populated. + udpAddr = candidates[rand.Intn(len(candidates))] + } + + de.bestAddr.AddrPort = udpAddr + // Only extend trustBestAddrUntil by one second to avoid packet + // reordering and/or CPU usage from random selection during the first + // second. We should receive a response due to a WireGuard handshake in + // less than one second in good cases, in which case this will be then + // extended to 15 seconds. + de.trustBestAddrUntil = now.Add(time.Second) + return udpAddr, needPing +} + +// heartbeat is called every heartbeatInterval to keep the best UDP path alive, +// or kick off discovery of other paths. +func (de *endpoint) heartbeat() { + de.mu.Lock() + defer de.mu.Unlock() + + de.heartBeatTimer = nil + + if de.heartbeatDisabled { + // If control override to disable heartBeatTimer set, return early. + return + } + + if de.lastSend.IsZero() { + // Shouldn't happen. + return + } + + if mono.Since(de.lastSend) > sessionActiveTimeout { + // Session's idle. Stop heartbeating. + de.c.dlogf("[v1] magicsock: disco: ending heartbeats for idle session to %v (%v)", de.publicKey.ShortString(), de.discoShort()) + return + } + + now := mono.Now() + udpAddr, _, _ := de.addrForSendLocked(now) + if udpAddr.IsValid() { + // We have a preferred path. Ping that every 2 seconds. 
+ de.startDiscoPingLocked(udpAddr, now, pingHeartbeat, 0, nil, nil) + } + + if de.wantFullPingLocked(now) { + de.sendDiscoPingsLocked(now, true) + } + + de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) +} + +// wantFullPingLocked reports whether we should ping to all our peers looking for +// a better path. +// +// de.mu must be held. +func (de *endpoint) wantFullPingLocked(now mono.Time) bool { + if runtime.GOOS == "js" { + return false + } + if !de.bestAddr.IsValid() || de.lastFullPing.IsZero() { + return true + } + if now.After(de.trustBestAddrUntil) { + return true + } + if de.bestAddr.latency <= goodEnoughLatency { + return false + } + if now.Sub(de.lastFullPing) >= upgradeInterval { + return true + } + return false +} + +func (de *endpoint) noteActiveLocked() { + de.lastSend = mono.Now() + if de.heartBeatTimer == nil && !de.heartbeatDisabled { + de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) + } +} + +// cliPing starts a ping for the "tailscale ping" command. res is value to call cb with, +// already partially filled. +func (de *endpoint) cliPing(res *ipnstate.PingResult, size int, cb func(*ipnstate.PingResult)) { + de.mu.Lock() + defer de.mu.Unlock() + + if de.expired { + res.Err = errExpired.Error() + cb(res) + return + } + + now := mono.Now() + udpAddr, derpAddr, _ := de.addrForSendLocked(now) + + if derpAddr.IsValid() { + de.startDiscoPingLocked(derpAddr, now, pingCLI, size, res, cb) + } + if udpAddr.IsValid() && now.Before(de.trustBestAddrUntil) { + // Already have an active session, so just ping the address we're using. + // Otherwise "tailscale ping" results to a node on the local network + // can look like they're bouncing between, say 10.0.0.0/9 and the peer's + // IPv6 address, both 1ms away, and it's random who replies first. 
+ de.startDiscoPingLocked(udpAddr, now, pingCLI, size, res, cb) + } else { + for ep := range de.endpointState { + de.startDiscoPingLocked(ep, now, pingCLI, size, res, cb) + } + } + de.noteActiveLocked() +} + +var ( + errExpired = errors.New("peer's node key has expired") + errNoUDPOrDERP = errors.New("no UDP or DERP addr") +) + +func (de *endpoint) send(buffs [][]byte) error { + de.mu.Lock() + if de.expired { + de.mu.Unlock() + return errExpired + } + + now := mono.Now() + udpAddr, derpAddr, startWGPing := de.addrForSendLocked(now) + + if de.isWireguardOnly { + if startWGPing { + de.sendWireGuardOnlyPingsLocked(now) + } + } else if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { + de.sendDiscoPingsLocked(now, true) + } + de.noteActiveLocked() + de.mu.Unlock() + + if !udpAddr.IsValid() && !derpAddr.IsValid() { + return errNoUDPOrDERP + } + var err error + if udpAddr.IsValid() { + _, err = de.c.sendUDPBatch(udpAddr, buffs) + + // If the error is known to indicate that the endpoint is no longer + // usable, clear the endpoint statistics so that the next send will + // re-evaluate the best endpoint. + if err != nil && isBadEndpointErr(err) { + de.noteBadEndpoint(udpAddr) + } + + // TODO(raggi): needs updating for accuracy, as in error conditions we may have partial sends. 
+ if stats := de.c.stats.Load(); err == nil && stats != nil { + var txBytes int + for _, b := range buffs { + txBytes += len(b) + } + stats.UpdateTxPhysical(de.nodeAddr, udpAddr, txBytes) + } + } + if derpAddr.IsValid() { + allOk := true + for _, buff := range buffs { + ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff) + if stats := de.c.stats.Load(); stats != nil { + stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buff)) + } + if !ok { + allOk = false + } + } + if allOk { + return nil + } + } + return err +} + +func (de *endpoint) discoPingTimeout(txid stun.TxID) { + de.mu.Lock() + defer de.mu.Unlock() + sp, ok := de.sentPing[txid] + if !ok { + return + } + if debugDisco() || !de.bestAddr.IsValid() || mono.Now().After(de.trustBestAddrUntil) { + de.c.dlogf("[v1] magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], sp.to, de.publicKey.ShortString(), de.discoShort()) + } + de.removeSentDiscoPingLocked(txid, sp) +} + +// forgetDiscoPing is called by a timer when a ping either fails to send or +// has taken too long to get a pong reply. +func (de *endpoint) forgetDiscoPing(txid stun.TxID) { + de.mu.Lock() + defer de.mu.Unlock() + if sp, ok := de.sentPing[txid]; ok { + de.removeSentDiscoPingLocked(txid, sp) + } +} + +func (de *endpoint) removeSentDiscoPingLocked(txid stun.TxID, sp sentPing) { + // Stop the timer for the case where sendPing failed to write to UDP. + // In the case of a timer already having fired, this is a no-op: + sp.timer.Stop() + delete(de.sentPing, txid) +} + +// discoPingSize is the size of a complete disco ping packet, without any padding. +const discoPingSize = len(disco.Magic) + key.DiscoPublicRawLen + disco.NonceLen + + poly1305.TagSize + disco.MessageHeaderLen + disco.PingLen + +// sendDiscoPing sends a ping with the provided txid to ep using de's discoKey. size +// is the desired disco message size, including all disco headers but excluding IP/UDP +// headers. 
+// +// The caller (startPingLocked) should've already recorded the ping in +// sentPing and set up the timer. +// +// The caller should use de.discoKey as the discoKey argument. +// It is passed in so that sendDiscoPing doesn't need to lock de.mu. +func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, txid stun.TxID, size int, logLevel discoLogLevel) { + padding := 0 + if size > int(tstun.DefaultMTU()) { + size = int(tstun.DefaultMTU()) + } + if size-discoPingSize > 0 { + padding = size - discoPingSize + } + sent, _ := de.c.sendDiscoMessage(ep, de.publicKey, discoKey, &disco.Ping{ + TxID: [12]byte(txid), + NodeKey: de.c.publicKeyAtomic.Load(), + Padding: padding, + }, logLevel) + if !sent { + de.forgetDiscoPing(txid) + } +} + +// discoPingPurpose is the reason why a discovery ping message was sent. +type discoPingPurpose int + +//go:generate go run tailscale.com/cmd/addlicense -file discopingpurpose_string.go go run golang.org/x/tools/cmd/stringer -type=discoPingPurpose -trimprefix=ping +const ( + // pingDiscovery means that purpose of a ping was to see if a + // path was valid. + pingDiscovery discoPingPurpose = iota + + // pingHeartbeat means that purpose of a ping was whether a + // peer was still there. + pingHeartbeat + + // pingCLI means that the user is running "tailscale ping" + // from the CLI. These types of pings can go over DERP. + pingCLI +) + +// startDiscoPingLocked sends a disco ping to ep in a separate +// goroutine. res and cb are for returning the results of CLI pings, +// otherwise they are nil. +func (de *endpoint) startDiscoPingLocked(ep netip.AddrPort, now mono.Time, purpose discoPingPurpose, size int, res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { + if runtime.GOOS == "js" { + return + } + epDisco := de.disco.Load() + if epDisco == nil { + return + } + if purpose != pingCLI { + st, ok := de.endpointState[ep] + if !ok { + // Shouldn't happen. But don't ping an endpoint that's + // not active for us. 
+ de.c.logf("magicsock: disco: [unexpected] attempt to ping no longer live endpoint %v", ep) + return + } + st.lastPing = now + } + + txid := stun.NewTxID() + de.sentPing[txid] = sentPing{ + to: ep, + at: now, + timer: time.AfterFunc(pingTimeoutDuration, func() { de.discoPingTimeout(txid) }), + purpose: purpose, + res: res, + cb: cb, + } + + logLevel := discoLog + if purpose == pingHeartbeat { + logLevel = discoVerboseLog + } + go de.sendDiscoPing(ep, epDisco.key, txid, size, logLevel) +} + +// sendDiscoPingsLocked starts pinging all of ep's endpoints. +func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) { + de.lastFullPing = now + var sentAny bool + for ep, st := range de.endpointState { + if st.shouldDeleteLocked() { + de.deleteEndpointLocked("sendPingsLocked", ep) + continue + } + if runtime.GOOS == "js" { + continue + } + if !st.lastPing.IsZero() && now.Sub(st.lastPing) < discoPingInterval { + continue + } + + firstPing := !sentAny + sentAny = true + + if firstPing && sendCallMeMaybe { + de.c.dlogf("[v1] magicsock: disco: send, starting discovery for %v (%v)", de.publicKey.ShortString(), de.discoShort()) + } + + de.startDiscoPingLocked(ep, now, pingDiscovery, 0, nil, nil) + } + derpAddr := de.derpAddr + if sentAny && sendCallMeMaybe && derpAddr.IsValid() { + // Have our magicsock.Conn figure out its STUN endpoint (if + // it doesn't know already) and then send a CallMeMaybe + // message to our peer via DERP informing them that we've + // sent so our firewall ports are probably open and now + // would be a good time for them to connect. + go de.c.enqueueCallMeMaybe(derpAddr, de) + } +} + +// sendWireGuardOnlyPingsLocked evaluates all available addresses for +// a WireGuard only endpoint and initates an ICMP ping for useable +// addresses. 
+func (de *endpoint) sendWireGuardOnlyPingsLocked(now mono.Time) { + if runtime.GOOS == "js" { + return + } + + // Normally the we only send pings at a low rate as the decision to start + // sending a ping sets bestAddrAtUntil with a reasonable time to keep trying + // that address, however, if that code changed we may want to be sure that + // we don't ever send excessive pings to avoid impact to the client/user. + if !now.After(de.lastFullPing.Add(10 * time.Second)) { + return + } + de.lastFullPing = now + + for ipp := range de.endpointState { + if ipp.Addr().Is4() && de.c.noV4.Load() { + continue + } + if ipp.Addr().Is6() && de.c.noV6.Load() { + continue + } + + go de.sendWireGuardOnlyPing(ipp, now) + } +} + +// sendWireGuardOnlyPing sends a ICMP ping to a WireGuard only address to +// discover the latency. +func (de *endpoint) sendWireGuardOnlyPing(ipp netip.AddrPort, now mono.Time) { + ctx, cancel := context.WithTimeout(de.c.connCtx, 5*time.Second) + defer cancel() + + de.setLastPing(ipp, now) + + addr := &net.IPAddr{ + IP: net.IP(ipp.Addr().AsSlice()), + Zone: ipp.Addr().Zone(), + } + + p := de.c.getPinger() + if p == nil { + de.c.logf("[v2] magicsock: sendWireGuardOnlyPingLocked: pinger is nil") + return + } + + latency, err := p.Send(ctx, addr, nil) + if err != nil { + de.c.logf("[v2] magicsock: sendWireGuardOnlyPingLocked: %s", err) + return + } + + de.mu.Lock() + defer de.mu.Unlock() + + state, ok := de.endpointState[ipp] + if !ok { + return + } + state.addPongReplyLocked(pongReply{ + latency: latency, + pongAt: now, + from: ipp, + pongSrc: netip.AddrPort{}, // We don't know this. + }) +} + +// setLastPing sets lastPing on the endpointState to now. +func (de *endpoint) setLastPing(ipp netip.AddrPort, now mono.Time) { + de.mu.Lock() + defer de.mu.Unlock() + state, ok := de.endpointState[ipp] + if !ok { + return + } + state.lastPing = now +} + +// updateFromNode updates the endpoint based on a tailcfg.Node from a NetMap +// update. 
+func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool) { + if !n.Valid() { + panic("nil node when updating endpoint") + } + de.mu.Lock() + defer de.mu.Unlock() + + de.heartbeatDisabled = heartbeatDisabled + de.expired = n.Expired() + + epDisco := de.disco.Load() + var discoKey key.DiscoPublic + if epDisco != nil { + discoKey = epDisco.key + } + + if discoKey != n.DiscoKey() { + de.c.logf("[v1] magicsock: disco: node %s changed from %s to %s", de.publicKey.ShortString(), discoKey, n.DiscoKey()) + de.disco.Store(&endpointDisco{ + key: n.DiscoKey(), + short: n.DiscoKey().ShortString(), + }) + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "updateFromNode-resetLocked", + }) + de.resetLocked() + } + if n.DERP() == "" { + if de.derpAddr.IsValid() { + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "updateFromNode-remove-DERP", + From: de.derpAddr, + }) + } + de.derpAddr = netip.AddrPort{} + } else { + newDerp, _ := netip.ParseAddrPort(n.DERP()) + if de.derpAddr != newDerp { + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "updateFromNode-DERP", + From: de.derpAddr, + To: newDerp, + }) + } + de.derpAddr = newDerp + } + + de.setEndpointsLocked(addrPortsFromStringsView{n.Endpoints()}) +} + +// addrPortsFromStringsView converts a view of AddrPort strings +// to a view-like thing of netip.AddrPort. +// TODO(bradfitz): change the type of tailcfg.Node.Endpoint. 
+type addrPortsFromStringsView struct { + views.Slice[string] +} + +func (a addrPortsFromStringsView) At(i int) netip.AddrPort { + ap, _ := netip.ParseAddrPort(a.Slice.At(i)) + return ap // or the zero value on error +} + +func (de *endpoint) setEndpointsLocked(eps interface { + LenIter() []struct{} + At(i int) netip.AddrPort +}) { + for _, st := range de.endpointState { + st.index = indexSentinelDeleted // assume deleted until updated in next loop + } + + var newIpps []netip.AddrPort + for i := range eps.LenIter() { + if i > math.MaxInt16 { + // Seems unlikely. + break + } + ipp := eps.At(i) + if !ipp.IsValid() { + de.c.logf("magicsock: bogus netmap endpoint from %v", eps) + continue + } + if st, ok := de.endpointState[ipp]; ok { + st.index = int16(i) + } else { + de.endpointState[ipp] = &endpointState{index: int16(i)} + newIpps = append(newIpps, ipp) + } + } + if len(newIpps) > 0 { + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "updateFromNode-new-Endpoints", + To: newIpps, + }) + } + + // Now delete anything unless it's still in the network map or + // was a recently discovered endpoint. + for ep, st := range de.endpointState { + if st.shouldDeleteLocked() { + de.deleteEndpointLocked("updateFromNode", ep) + } + } +} + +// addCandidateEndpoint adds ep as an endpoint to which we should send +// future pings. If there is an existing endpointState for ep, and forRxPingTxID +// matches the last received ping TxID, this function reports true, otherwise +// false. +// +// This is called once we've already verified that we got a valid +// discovery message from de via ep. 
+func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.TxID) (duplicatePing bool) { + de.mu.Lock() + defer de.mu.Unlock() + + if st, ok := de.endpointState[ep]; ok { + duplicatePing = forRxPingTxID == st.lastGotPingTxID + if !duplicatePing { + st.lastGotPingTxID = forRxPingTxID + } + if st.lastGotPing.IsZero() { + // Already-known endpoint from the network map. + return duplicatePing + } + st.lastGotPing = time.Now() + return duplicatePing + } + + // Newly discovered endpoint. Exciting! + de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort(), de.publicKey.ShortString()) + de.endpointState[ep] = &endpointState{ + lastGotPing: time.Now(), + lastGotPingTxID: forRxPingTxID, + } + + // If for some reason this gets very large, do some cleanup. + if size := len(de.endpointState); size > 100 { + for ep, st := range de.endpointState { + if st.shouldDeleteLocked() { + de.deleteEndpointLocked("addCandidateEndpoint", ep) + } + } + size2 := len(de.endpointState) + de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2) + } + return false +} + +// clearBestAddrLocked clears the bestAddr and related fields such that future +// packets will re-evaluate the best address to send to next. +// +// de.mu must be held. +func (de *endpoint) clearBestAddrLocked() { + de.bestAddr = addrLatency{} + de.bestAddrAt = 0 + de.trustBestAddrUntil = 0 +} + +// noteBadEndpoint marks ipp as a bad endpoint that would need to be +// re-evaluated before future use, this should be called for example if a send +// to ipp fails due to a host unreachable error or similar. 
+func (de *endpoint) noteBadEndpoint(ipp netip.AddrPort) { + de.mu.Lock() + defer de.mu.Unlock() + + de.clearBestAddrLocked() + + if st, ok := de.endpointState[ipp]; ok { + st.clear() + } +} + +// noteConnectivityChange is called when connectivity changes enough +// that we should question our earlier assumptions about which paths +// work. +func (de *endpoint) noteConnectivityChange() { + de.mu.Lock() + defer de.mu.Unlock() + + de.clearBestAddrLocked() + + for k := range de.endpointState { + de.endpointState[k].clear() + } +} + +// handlePongConnLocked handles a Pong message (a reply to an earlier ping). +// It should be called with the Conn.mu held. +// +// It reports whether m.TxID corresponds to a ping that this endpoint sent. +func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip.AddrPort) (knownTxID bool) { + de.mu.Lock() + defer de.mu.Unlock() + + isDerp := src.Addr() == tailcfg.DerpMagicIPAddr + + sp, ok := de.sentPing[m.TxID] + if !ok { + // This is not a pong for a ping we sent. + return false + } + knownTxID = true // for naked returns below + de.removeSentDiscoPingLocked(m.TxID, sp) + + now := mono.Now() + latency := now.Sub(sp.at) + + if !isDerp { + st, ok := de.endpointState[sp.to] + if !ok { + // This is no longer an endpoint we care about. + return + } + + de.c.peerMap.setNodeKeyForIPPort(src, de.publicKey) + + st.addPongReplyLocked(pongReply{ + latency: latency, + pongAt: now, + from: src, + pongSrc: m.Src, + }) + } + + if sp.purpose != pingHeartbeat { + de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pong.src=%v%v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), m.Src, logger.ArgWriter(func(bw *bufio.Writer) { + if sp.to != src { + fmt.Fprintf(bw, " ping.to=%v", sp.to) + } + })) + } + + // Currently only CLI ping uses this callback. 
+ if sp.cb != nil { + if sp.purpose == pingCLI { + de.c.populateCLIPingResponseLocked(sp.res, latency, sp.to) + } + go sp.cb(sp.res) + } + + // Promote this pong response to our current best address if it's lower latency. + // TODO(bradfitz): decide how latency vs. preference order affects decision + if !isDerp { + thisPong := addrLatency{sp.to, latency} + if betterAddr(thisPong, de.bestAddr) { + de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort(), sp.to) + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "handlePingLocked-bestAddr-update", + From: de.bestAddr, + To: thisPong, + }) + de.bestAddr = thisPong + } + if de.bestAddr.AddrPort == thisPong.AddrPort { + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "handlePingLocked-bestAddr-latency", + From: de.bestAddr, + To: thisPong, + }) + de.bestAddr.latency = latency + de.bestAddrAt = now + de.trustBestAddrUntil = now.Add(trustUDPAddrDuration) + } + } + return +} + +// addrLatency is an IPPort with an associated latency. +type addrLatency struct { + netip.AddrPort + latency time.Duration +} + +func (a addrLatency) String() string { + return a.AddrPort.String() + "@" + a.latency.String() +} + +// betterAddr reports whether a is a better addr to use than b. +func betterAddr(a, b addrLatency) bool { + if a.AddrPort == b.AddrPort { + return false + } + if !b.IsValid() { + return true + } + if !a.IsValid() { + return false + } + + // Each address starts with a set of points (from 0 to 100) that + // represents how much faster they are than the highest-latency + // endpoint. For example, if a has latency 200ms and b has latency + // 190ms, then a starts with 0 points and b starts with 5 points since + // it's 5% faster. 
+ var aPoints, bPoints int + if a.latency > b.latency && a.latency > 0 { + bPoints = int(100 - ((b.latency * 100) / a.latency)) + } else if b.latency > 0 { + aPoints = int(100 - ((a.latency * 100) / b.latency)) + } + + // Prefer private IPs over public IPs as long as the latencies are + // roughly equivalent, since it's less likely that a user will have to + // pay for the bandwidth in a cloud environment. + // + // Additionally, prefer any loopback address strongly over non-loopback + // addresses. + if a.Addr().IsLoopback() { + aPoints += 50 + } else if a.Addr().IsPrivate() { + aPoints += 20 + } + if b.Addr().IsLoopback() { + bPoints += 50 + } else if b.Addr().IsPrivate() { + bPoints += 20 + } + + // Prefer IPv6 for being a bit more robust, as long as + // the latencies are roughly equivalent. + if a.Addr().Is6() { + aPoints += 10 + } + if b.Addr().Is6() { + bPoints += 10 + } + + // Don't change anything if the latency improvement is less than 1%; we + // want a bit of "stickiness" (a.k.a. hysteresis) to avoid flapping if + // there's two roughly-equivalent endpoints. + // + // Points are essentially the percentage improvement of latency vs. the + // slower endpoint; absent any boosts from private IPs, IPv6, etc., a + // will be a better address than b by a fraction of 1% or less if + // aPoints <= 1 and bPoints == 0. + if aPoints <= 1 && bPoints == 0 { + return false + } + + return aPoints > bPoints +} + +// handleCallMeMaybe handles a CallMeMaybe discovery message via +// DERP. The contract for use of this message is that the peer has +// already sent to us via UDP, so their stateful firewall should be +// open. Now we can Ping back and make it through. +func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { + if runtime.GOOS == "js" { + // Nothing to do on js/wasm if we can't send UDP packets anyway. 
+ return + } + de.mu.Lock() + defer de.mu.Unlock() + + now := time.Now() + for ep := range de.isCallMeMaybeEP { + de.isCallMeMaybeEP[ep] = false // mark for deletion + } + var newEPs []netip.AddrPort + for _, ep := range m.MyNumber { + if ep.Addr().Is6() && ep.Addr().IsLinkLocalUnicast() { + // We send these out, but ignore them for now. + // TODO: teach the ping code to ping on all interfaces + // for these. + continue + } + mak.Set(&de.isCallMeMaybeEP, ep, true) + if es, ok := de.endpointState[ep]; ok { + es.callMeMaybeTime = now + } else { + de.endpointState[ep] = &endpointState{callMeMaybeTime: now} + newEPs = append(newEPs, ep) + } + } + if len(newEPs) > 0 { + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "handleCallMeMaybe-new-endpoints", + To: newEPs, + }) + + de.c.dlogf("[v1] magicsock: disco: call-me-maybe from %v %v added new endpoints: %v", + de.publicKey.ShortString(), de.discoShort(), + logger.ArgWriter(func(w *bufio.Writer) { + for i, ep := range newEPs { + if i > 0 { + w.WriteString(", ") + } + w.WriteString(ep.String()) + } + })) + } + + // Delete any prior CallMeMaybe endpoints that weren't included + // in this message. + for ep, want := range de.isCallMeMaybeEP { + if !want { + delete(de.isCallMeMaybeEP, ep) + de.deleteEndpointLocked("handleCallMeMaybe", ep) + } + } + + // Zero out all the lastPing times to force sendPingsLocked to send new ones, + // even if it's been less than 5 seconds ago. 
+ for _, st := range de.endpointState { + st.lastPing = 0 + } + de.sendDiscoPingsLocked(mono.Now(), false) +} + +func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { + de.mu.Lock() + defer de.mu.Unlock() + + ps.Relay = de.c.derpRegionCodeOfIDLocked(int(de.derpAddr.Port())) + + if de.lastSend.IsZero() { + return + } + + now := mono.Now() + ps.LastWrite = de.lastSend.WallTime() + ps.Active = now.Sub(de.lastSend) < sessionActiveTimeout + + if udpAddr, derpAddr, _ := de.addrForSendLocked(now); udpAddr.IsValid() && !derpAddr.IsValid() { + ps.CurAddr = udpAddr.String() + } +} + +// stopAndReset stops timers associated with de and resets its state back to zero. +// It's called when a discovery endpoint is no longer present in the +// NetworkMap, or when magicsock is transitioning from running to +// stopped state (via SetPrivateKey(zero)) +func (de *endpoint) stopAndReset() { + atomic.AddInt64(&de.numStopAndResetAtomic, 1) + de.mu.Lock() + defer de.mu.Unlock() + + if closing := de.c.closing.Load(); !closing { + de.c.logf("[v1] magicsock: doing cleanup for discovery key %s", de.discoShort()) + } + + de.debugUpdates.Add(EndpointChange{ + When: time.Now(), + What: "stopAndReset-resetLocked", + }) + de.resetLocked() + if de.heartBeatTimer != nil { + de.heartBeatTimer.Stop() + de.heartBeatTimer = nil + } +} + +// resetLocked clears all the endpoint's p2p state, reverting it to a +// DERP-only endpoint. It does not stop the endpoint's heartbeat +// timer, if one is running. 
+func (de *endpoint) resetLocked() { + de.lastSend = 0 + de.lastFullPing = 0 + de.clearBestAddrLocked() + for _, es := range de.endpointState { + es.lastPing = 0 + } + for txid, sp := range de.sentPing { + de.removeSentDiscoPingLocked(txid, sp) + } +} + +func (de *endpoint) numStopAndReset() int64 { + return atomic.LoadInt64(&de.numStopAndResetAtomic) +} + +func (de *endpoint) setDERPHome(regionID uint16) { + de.mu.Lock() + defer de.mu.Unlock() + de.derpAddr = netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID)) +} diff --git a/vendor/tailscale.com/wgengine/magicsock/endpoint_default.go b/vendor/tailscale.com/wgengine/magicsock/endpoint_default.go new file mode 100644 index 0000000000..1ed6e5e0e2 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/endpoint_default.go @@ -0,0 +1,22 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !wasm && !plan9 + +package magicsock + +import ( + "errors" + "syscall" +) + +// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to +// errors.Is while avoiding an allocation per call. +var errHOSTUNREACH error = syscall.EHOSTUNREACH + +// isBadEndpointErr checks if err is one which is known to report that an +// endpoint can no longer be sent to. It is not exhaustive, and for unknown +// errors always reports false. +func isBadEndpointErr(err error) bool { + return errors.Is(err, errHOSTUNREACH) +} diff --git a/vendor/tailscale.com/wgengine/magicsock/endpoint_stub.go b/vendor/tailscale.com/wgengine/magicsock/endpoint_stub.go new file mode 100644 index 0000000000..a209c352bf --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/endpoint_stub.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build wasm || plan9 + +package magicsock + +// isBadEndpointErr checks if err is one which is known to report that an +// endpoint can no longer be sent to. 
It is not exhaustive, but covers known +// cases. +func isBadEndpointErr(err error) bool { + return false +} diff --git a/vendor/tailscale.com/wgengine/magicsock/endpoint_tracker.go b/vendor/tailscale.com/wgengine/magicsock/endpoint_tracker.go new file mode 100644 index 0000000000..5caddd1a06 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/endpoint_tracker.go @@ -0,0 +1,248 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "slices" + "sync" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tempfork/heap" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + // endpointTrackerLifetime is how long we continue advertising an + // endpoint after we last see it. This is intentionally chosen to be + // slightly longer than a full netcheck period. + endpointTrackerLifetime = 5*time.Minute + 10*time.Second + + // endpointTrackerMaxPerAddr is how many cached addresses we track for + // a given netip.Addr. This allows e.g. restricting the number of STUN + // endpoints we cache (which usually have the same netip.Addr but + // different ports). + // + // The value of 6 is chosen because we can advertise up to 3 endpoints + // based on the STUN IP: + // 1. The STUN endpoint itself (EndpointSTUN) + // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) + // 3. The STUN IP with a portmapped port (EndpointPortmapped) + // + // Storing 6 endpoints in the cache means we can store up to 2 previous + // sets of endpoints. + endpointTrackerMaxPerAddr = 6 +) + +// endpointTrackerEntry is an entry in an endpointHeap that stores the state of +// a given cached endpoint. +type endpointTrackerEntry struct { + // endpoint is the cached endpoint. + endpoint tailcfg.Endpoint + // until is the time until which this endpoint is being cached. + until time.Time + // index is the index within the containing endpointHeap. 
+ index int +} + +// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in +// ascending order by the 'until' expiry time (i.e. oldest first). +type endpointHeap []*endpointTrackerEntry + +var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) + +// Len implements heap.Interface. +func (eh endpointHeap) Len() int { return len(eh) } + +// Less implements heap.Interface. +func (eh endpointHeap) Less(i, j int) bool { + // We want to store items so that the lowest item in the heap is the + // oldest, so that heap.Pop()-ing from the endpointHeap will remove the + // oldest entry. + return eh[i].until.Before(eh[j].until) +} + +// Swap implements heap.Interface. +func (eh endpointHeap) Swap(i, j int) { + eh[i], eh[j] = eh[j], eh[i] + eh[i].index = i + eh[j].index = j +} + +// Push implements heap.Interface. +func (eh *endpointHeap) Push(item *endpointTrackerEntry) { + n := len(*eh) + item.index = n + *eh = append(*eh, item) +} + +// Pop implements heap.Interface. +func (eh *endpointHeap) Pop() *endpointTrackerEntry { + old := *eh + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.index = -1 // for safety + *eh = old[0 : n-1] + return item +} + +// Min returns a pointer to the minimum element in the heap, without removing +// it. Since this is a min-heap ordered by the 'until' field, this returns the +// chronologically "earliest" element in the heap. +// +// Len() must be non-zero. +func (eh endpointHeap) Min() *endpointTrackerEntry { + return eh[0] +} + +// endpointTracker caches endpoints that are advertised to peers. This allows +// peers to still reach this node if there's a temporary endpoint flap; rather +// than withdrawing an endpoint and then re-advertising it the next time we run +// a netcheck, we keep advertising the endpoint until it's not present for a +// defined timeout. +// +// See tailscale/tailscale#7877 for more information. 
+type endpointTracker struct {
+	mu        sync.Mutex
+	endpoints map[netip.Addr]*endpointHeap
+}
+
+// update takes as input the current set of discovered endpoints and the
+// current time, and returns the set of endpoints plus any previous-cached and
+// non-expired endpoints that should be advertised to peers.
+func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
+	var inputEps set.Slice[netip.AddrPort]
+	for _, ep := range eps {
+		inputEps.Add(ep.Addr)
+	}
+
+	et.mu.Lock()
+	defer et.mu.Unlock()
+
+	// Extend endpoints that already exist in the cache. We do this before
+	// we remove expired endpoints, below, so we don't remove something
+	// that would otherwise have survived by extending.
+	until := now.Add(endpointTrackerLifetime)
+	for _, ep := range eps {
+		et.extendLocked(ep, until)
+	}
+
+	// Now that we've extended existing endpoints, remove everything that
+	// has expired.
+	et.removeExpiredLocked(now)
+
+	// Add entries from the input set of endpoints into the cache; we do
+	// this after removing expired ones so that we can store as many as
+	// possible, with space freed by the entries removed after expiry.
+	for _, ep := range eps {
+		et.addLocked(now, ep, until)
+	}
+
+	// Finally, add entries to the return array that aren't already there.
+	epsPlusCached = eps
+	for _, heap := range et.endpoints {
+		for _, ep := range *heap {
+			// If the endpoint was in the input list, or has expired, skip it.
+			if inputEps.Contains(ep.endpoint.Addr) {
+				continue
+			} else if now.After(ep.until) {
+				// Defense-in-depth; should never happen since
+				// we removed expired entries above, but ignore
+				// it anyway.
+				continue
+			}
+
+			// We haven't seen this endpoint; add to the return array
+			epsPlusCached = append(epsPlusCached, ep.endpoint)
+		}
+	}
+
+	return epsPlusCached
+}
+
+// extendLocked will update the expiry time of the provided endpoint in the
+// cache, if it is present. If it is not present, nothing will be done.
+//
+// et.mu must be held.
+func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) {
+	key := ep.Addr.Addr()
+	epHeap, found := et.endpoints[key]
+	if !found {
+		return
+	}
+
+	// Find the entry for this exact address; this loop is quick since we
+	// bound the number of items in the heap.
+	//
+	// TODO(andrew): this means we iterate over the entire heap once per
+	// endpoint; even if the heap is small, if we have a lot of input
+	// endpoints this can be expensive?
+	for i, entry := range *epHeap {
+		if entry.endpoint == ep {
+			entry.until = until
+			heap.Fix(epHeap, i)
+			return
+		}
+	}
+}
+
+// addLocked will store the provided endpoint(s) in the cache for a fixed
+// period of time, ensuring that the size of the endpoint cache remains below
+// the maximum.
+//
+// et.mu must be held.
+func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
+	key := ep.Addr.Addr()
+
+	// Create or get the heap for this endpoint's addr
+	epHeap := et.endpoints[key]
+	if epHeap == nil {
+		epHeap = new(endpointHeap)
+		mak.Set(&et.endpoints, key, epHeap)
+	}
+
+	// Find the entry for this exact address; this loop is quick
+	// since we bound the number of items in the heap.
+	found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool {
+		return v.endpoint == ep
+	})
+	if !found {
+		// Add address to heap; either the endpoint is new, or the heap
+		// was newly-created and thus empty.
+		heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until})
+	}
+
+	// Now that we've added everything, pop from our heap until we're below
+	// the limit. This is a min-heap, so popping removes the lowest (and
+	// thus oldest) endpoint.
+	for epHeap.Len() > endpointTrackerMaxPerAddr {
+		heap.Pop(epHeap)
+	}
+}
+
+// removeExpiredLocked will remove all expired entries from the cache.
+//
+// et.mu must be held.
+func (et *endpointTracker) removeExpiredLocked(now time.Time) { + for k, epHeap := range et.endpoints { + // The minimum element is oldest/earliest endpoint; repeatedly + // pop from the heap while it's in the past. + for epHeap.Len() > 0 { + minElem := epHeap.Min() + if now.After(minElem.until) { + heap.Pop(epHeap) + } else { + break + } + } + + if epHeap.Len() == 0 { + // Free up space in the map by removing the empty heap. + delete(et.endpoints, k) + } + } +} diff --git a/vendor/tailscale.com/wgengine/magicsock/magicsock.go b/vendor/tailscale.com/wgengine/magicsock/magicsock.go index dea8b2d97b..cdb793e39b 100644 --- a/vendor/tailscale.com/wgengine/magicsock/magicsock.go +++ b/vendor/tailscale.com/wgengine/magicsock/magicsock.go @@ -7,21 +7,13 @@ package magicsock import ( "bufio" - "bytes" "context" - crand "crypto/rand" - "encoding/binary" "errors" "fmt" - "hash/fnv" "io" - "math" - "math/rand" "net" "net/netip" - "reflect" "runtime" - "sort" "strconv" "strings" "sync" @@ -32,19 +24,15 @@ import ( "go4.org/mem" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "tailscale.com/control/controlclient" - "tailscale.com/derp" - "tailscale.com/derp/derphttp" + + "tailscale.com/control/controlknobs" "tailscale.com/disco" "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn/ipnstate" - "tailscale.com/logtail/backoff" "tailscale.com/net/connstats" - "tailscale.com/net/dnscache" "tailscale.com/net/interfaces" - "tailscale.com/net/netaddr" "tailscale.com/net/netcheck" "tailscale.com/net/neterror" "tailscale.com/net/netmon" @@ -54,7 +42,6 @@ import ( "tailscale.com/net/portmapper" "tailscale.com/net/sockstats" "tailscale.com/net/stun" - "tailscale.com/net/tsaddr" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" @@ -64,13 +51,13 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/nettype" + "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" 
"tailscale.com/util/ringbuffer" "tailscale.com/util/set" - "tailscale.com/util/sysresources" + "tailscale.com/util/testenv" "tailscale.com/util/uniq" - "tailscale.com/version" "tailscale.com/wgengine/capture" ) @@ -89,192 +76,6 @@ const ( socketBufferSize = 7 << 20 ) -// useDerpRoute reports whether magicsock should enable the DERP -// return path optimization (Issue 150). -func useDerpRoute() bool { - if b, ok := debugUseDerpRoute().Get(); ok { - return b - } - ob := controlclient.DERPRouteFlag() - if v, ok := ob.Get(); ok { - return v - } - return true // as of 1.21.x -} - -// peerInfo is all the information magicsock tracks about a particular -// peer. -type peerInfo struct { - ep *endpoint // always non-nil. - // ipPorts is an inverted version of peerMap.byIPPort (below), so - // that when we're deleting this node, we can rapidly find out the - // keys that need deleting from peerMap.byIPPort without having to - // iterate over every IPPort known for any peer. - ipPorts map[netip.AddrPort]bool -} - -func newPeerInfo(ep *endpoint) *peerInfo { - return &peerInfo{ - ep: ep, - ipPorts: map[netip.AddrPort]bool{}, - } -} - -// peerMap is an index of peerInfos by node (WireGuard) key, disco -// key, and discovered ip:port endpoints. -// -// Doesn't do any locking, all access must be done with Conn.mu held. -type peerMap struct { - byNodeKey map[key.NodePublic]*peerInfo - byIPPort map[netip.AddrPort]*peerInfo - - // nodesOfDisco contains the set of nodes that are using a - // DiscoKey. Usually those sets will be just one node. - nodesOfDisco map[key.DiscoPublic]map[key.NodePublic]bool -} - -func newPeerMap() peerMap { - return peerMap{ - byNodeKey: map[key.NodePublic]*peerInfo{}, - byIPPort: map[netip.AddrPort]*peerInfo{}, - nodesOfDisco: map[key.DiscoPublic]map[key.NodePublic]bool{}, - } -} - -// nodeCount returns the number of nodes currently in m. 
-func (m *peerMap) nodeCount() int { - return len(m.byNodeKey) -} - -// anyEndpointForDiscoKey reports whether there exists any -// peers in the netmap with dk as their DiscoKey. -func (m *peerMap) anyEndpointForDiscoKey(dk key.DiscoPublic) bool { - return len(m.nodesOfDisco[dk]) > 0 -} - -// endpointForNodeKey returns the endpoint for nk, or nil if -// nk is not known to us. -func (m *peerMap) endpointForNodeKey(nk key.NodePublic) (ep *endpoint, ok bool) { - if nk.IsZero() { - return nil, false - } - if info, ok := m.byNodeKey[nk]; ok { - return info.ep, true - } - return nil, false -} - -// endpointForIPPort returns the endpoint for the peer we -// believe to be at ipp, or nil if we don't know of any such peer. -func (m *peerMap) endpointForIPPort(ipp netip.AddrPort) (ep *endpoint, ok bool) { - if info, ok := m.byIPPort[ipp]; ok { - return info.ep, true - } - return nil, false -} - -// forEachEndpoint invokes f on every endpoint in m. -func (m *peerMap) forEachEndpoint(f func(ep *endpoint)) { - for _, pi := range m.byNodeKey { - f(pi.ep) - } -} - -// forEachEndpointWithDiscoKey invokes f on every endpoint in m that has the -// provided DiscoKey until f returns false or there are no endpoints left to -// iterate. -func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(*endpoint) (keepGoing bool)) { - for nk := range m.nodesOfDisco[dk] { - pi, ok := m.byNodeKey[nk] - if !ok { - // Unexpected. Data structures would have to - // be out of sync. But we don't have a logger - // here to log [unexpected], so just skip. - // Maybe log later once peerMap is merged back - // into Conn. - continue - } - if !f(pi.ep) { - return - } - } -} - -// upsertEndpoint stores endpoint in the peerInfo for -// ep.publicKey, and updates indexes. m must already have a -// tailcfg.Node for ep.publicKey. 
-func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) { - if m.byNodeKey[ep.publicKey] == nil { - m.byNodeKey[ep.publicKey] = newPeerInfo(ep) - } - epDisco := ep.disco.Load() - if epDisco == nil || oldDiscoKey != epDisco.key { - delete(m.nodesOfDisco[oldDiscoKey], ep.publicKey) - } - if ep.isWireguardOnly { - // If the peer is a WireGuard only peer, add all of its endpoints. - - // TODO(raggi,catzkorn): this could mean that if a "isWireguardOnly" - // peer has, say, 192.168.0.2 and so does a tailscale peer, the - // wireguard one will win. That may not be the outcome that we want - - // perhaps we should prefer bestAddr.AddrPort if it is set? - // see tailscale/tailscale#7994 - for ipp := range ep.endpointState { - m.setNodeKeyForIPPort(ipp, ep.publicKey) - } - - return - } - set := m.nodesOfDisco[epDisco.key] - if set == nil { - set = map[key.NodePublic]bool{} - m.nodesOfDisco[epDisco.key] = set - } - set[ep.publicKey] = true -} - -// setNodeKeyForIPPort makes future peer lookups by ipp return the -// same endpoint as a lookup by nk. -// -// This should only be called with a fully verified mapping of ipp to -// nk, because calling this function defines the endpoint we hand to -// WireGuard for packets received from ipp. -func (m *peerMap) setNodeKeyForIPPort(ipp netip.AddrPort, nk key.NodePublic) { - if pi := m.byIPPort[ipp]; pi != nil { - delete(pi.ipPorts, ipp) - delete(m.byIPPort, ipp) - } - if pi, ok := m.byNodeKey[nk]; ok { - pi.ipPorts[ipp] = true - m.byIPPort[ipp] = pi - } -} - -// deleteEndpoint deletes the peerInfo associated with ep, and -// updates indexes. -func (m *peerMap) deleteEndpoint(ep *endpoint) { - if ep == nil { - return - } - ep.stopAndReset() - - epDisco := ep.disco.Load() - - pi := m.byNodeKey[ep.publicKey] - if epDisco != nil { - delete(m.nodesOfDisco[epDisco.key], ep.publicKey) - } - delete(m.byNodeKey, ep.publicKey) - if pi == nil { - // Kneejerk paranoia from earlier issue 2801. - // Unexpected. 
But no logger plumbed here to log so. - return - } - for ip := range pi.ipPorts { - delete(m.byIPPort, ip) - } -} - // A Conn routes UDP packets and actively manages a list of its endpoints. type Conn struct { // This block mirrors the contents and field order of the Options @@ -287,6 +88,7 @@ type Conn struct { testOnlyPacketListener nettype.PacketListener noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity netMon *netmon.Monitor // or nil + controlKnobs *controlknobs.Knobs // or nil // ================================================================ // No locking required to access these fields, either because @@ -319,10 +121,6 @@ type Conn struct { // port mappings from NAT devices. portMapper *portmapper.Client - // stunReceiveFunc holds the current STUN packet processing func. - // Its Loaded value is always non-nil. - stunReceiveFunc syncs.AtomicValue[func(p []byte, fromAddr netip.AddrPort)] - // derpRecvCh is used by receiveDERP to read DERP messages. // It must have buffer size > 0; see issue 3736. derpRecvCh chan derpReadResult @@ -367,6 +165,9 @@ type Conn struct { // port is the preferred port from opts.Port; 0 means auto. port atomic.Uint32 + // peerMTUEnabled is whether path MTU discovery to peers is enabled. + peerMTUEnabled atomic.Bool + // stats maintains per-connection counters. stats atomic.Pointer[connstats.Statistics] @@ -435,10 +236,10 @@ type Conn struct { // WireGuard. These are not used to filter inbound or outbound // traffic at all, but only to track what state can be cleaned up // in other maps below that are keyed by peer public key. - peerSet map[key.NodePublic]struct{} + peerSet set.Set[key.NodePublic] - // nodeOfDisco tracks the networkmap Node entity for each peer - // discovery key. + // peerMap tracks the networkmap Node entity for each peer + // by node key, node ID, and discovery key. peerMap peerMap // discoInfo is the state for an active DiscoKey. 
@@ -459,14 +260,16 @@ type Conn struct { // magicsock could do with any complexity reduction it can get. netInfoLast *tailcfg.NetInfo - derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled - netMap *netmap.NetworkMap - privateKey key.NodePrivate // WireGuard private key for this node - everHadKey bool // whether we ever had a non-zero private key - myDerp int // nearest DERP region ID; 0 means none/unknown - derpStarted chan struct{} // closed on first connection to DERP; for tests & cleaner Close - activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region - prevDerp map[int]*syncs.WaitGroupChan + derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled + peers views.Slice[tailcfg.NodeView] // from last SetNetworkMap update + lastFlags debugFlags // at time of last SetNetworkMap + firstAddrForTest netip.Addr // from last SetNetworkMap update; for tests only + privateKey key.NodePrivate // WireGuard private key for this node + everHadKey bool // whether we ever had a non-zero private key + myDerp int // nearest DERP region ID; 0 means none/unknown + derpStarted chan struct{} // closed on first connection to DERP; for tests & cleaner Close + activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region + prevDerp map[int]*syncs.WaitGroupChan // derpRoute contains optional alternate routes to use as an // optimization instead of contacting a peer via their home @@ -503,48 +306,6 @@ func (c *Conn) dlogf(format string, a ...any) { } } -// derpRoute is a route entry for a public key, saying that a certain -// peer should be available at DERP node derpID, as long as the -// current connection for that derpID is dc. 
(but dc should not be -// used to write directly; it's owned by the read/write loops) -type derpRoute struct { - derpID int - dc *derphttp.Client // don't use directly; see comment above -} - -// removeDerpPeerRoute removes a DERP route entry previously added by addDerpPeerRoute. -func (c *Conn) removeDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) { - c.mu.Lock() - defer c.mu.Unlock() - r2 := derpRoute{derpID, dc} - if r, ok := c.derpRoute[peer]; ok && r == r2 { - delete(c.derpRoute, peer) - } -} - -// addDerpPeerRoute adds a DERP route entry, noting that peer was seen -// on DERP node derpID, at least on the connection identified by dc. -// See issue 150 for details. -func (c *Conn) addDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) { - c.mu.Lock() - defer c.mu.Unlock() - mak.Set(&c.derpRoute, peer, derpRoute{derpID, dc}) -} - -var derpMagicIPAddr = netip.MustParseAddr(tailcfg.DerpMagicIP) - -// activeDerp contains fields for an active DERP connection. -type activeDerp struct { - c *derphttp.Client - cancel context.CancelFunc - writeCh chan<- derpWriteRequest - // lastWrite is the time of the last request for its write - // channel (currently even if there was no write). - // It is always non-nil and initialized to a non-zero Time. - lastWrite *time.Time - createTime time.Time -} - // Options contains options for Listen. type Options struct { // Logf optionally provides a log function to use. @@ -586,6 +347,10 @@ type Options struct { // NetMon is the network monitor to use. // With one, the portmapper won't be used. NetMon *netmon.Monitor + + // ControlKnobs are the set of control knobs to use. + // If nil, they're ignored and not updated. 
+ ControlKnobs *controlknobs.Knobs } func (o *Options) logf() logger.Logf { @@ -646,13 +411,14 @@ func newConn() *Conn { func NewConn(opts Options) (*Conn, error) { c := newConn() c.port.Store(uint32(opts.Port)) + c.controlKnobs = opts.ControlKnobs c.logf = opts.logf() c.epFunc = opts.endpointsFunc() c.derpActiveFunc = opts.derpActiveFunc() c.idleFunc = opts.IdleFunc c.testOnlyPacketListener = opts.TestOnlyPacketListener c.noteRecvActivity = opts.NoteRecvActivity - c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "), opts.NetMon, nil, c.onPortMapChanged) + c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "), opts.NetMon, nil, opts.ControlKnobs, c.onPortMapChanged) if opts.NetMon != nil { c.portMapper.SetGatewayLookupFunc(opts.NetMon.GatewayAndSelfIP) } @@ -665,17 +431,20 @@ func NewConn(opts Options) (*Conn, error) { c.connCtx, c.connCtxCancel = context.WithCancel(context.Background()) c.donec = c.connCtx.Done() c.netChecker = &netcheck.Client{ - Logf: logger.WithPrefix(c.logf, "netcheck: "), - NetMon: c.netMon, - GetSTUNConn4: func() netcheck.STUNConn { return &c.pconn4 }, - GetSTUNConn6: func() netcheck.STUNConn { return &c.pconn6 }, + Logf: logger.WithPrefix(c.logf, "netcheck: "), + NetMon: c.netMon, + SendPacket: func(b []byte, ap netip.AddrPort) (int, error) { + ok, err := c.sendUDP(ap, b) + if !ok { + return 0, err + } + return len(b), err + }, SkipExternalNetwork: inTest(), PortMapper: c.portMapper, UseDNSCache: true, } - c.ignoreSTUNPackets() - if d4, err := c.listenRawDisco("ip4"); err == nil { c.logf("[v1] using BPF disco receiver for IPv4") c.closeDisco4 = d4 @@ -701,11 +470,6 @@ func (c *Conn) InstallCaptureHook(cb capture.Callback) { c.captureHook.Store(cb) } -// ignoreSTUNPackets sets a STUN packet processing func that does nothing. 
-func (c *Conn) ignoreSTUNPackets() { - c.stunReceiveFunc.Store(func([]byte, netip.AddrPort) {}) -} - // doPeriodicSTUN is called (in a new goroutine) by // periodicReSTUNTimer when periodic STUNs are active. func (c *Conn) doPeriodicSTUN() { c.ReSTUN("periodic") } @@ -851,9 +615,6 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { ctx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() - c.stunReceiveFunc.Store(c.netChecker.ReceiveSTUNPacket) - defer c.ignoreSTUNPackets() - report, err := c.netChecker.GetReport(ctx, dm) if err != nil { return nil, err @@ -893,55 +654,13 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { if !c.setNearestDERP(ni.PreferredDERP) { ni.PreferredDERP = 0 } - - // TODO: set link type + ni.FirewallMode = hostinfo.FirewallMode() c.callNetInfoCallback(ni) return report, nil } -var processStartUnixNano = time.Now().UnixNano() - -// pickDERPFallback returns a non-zero but deterministic DERP node to -// connect to. This is only used if netcheck couldn't find the -// nearest one (for instance, if UDP is blocked and thus STUN latency -// checks aren't working). -// -// c.mu must NOT be held. -func (c *Conn) pickDERPFallback() int { - c.mu.Lock() - defer c.mu.Unlock() - - if !c.wantDerpLocked() { - return 0 - } - ids := c.derpMap.RegionIDs() - if len(ids) == 0 { - // No DERP regions in non-nil map. - return 0 - } - - // TODO: figure out which DERP region most of our peers are using, - // and use that region as our fallback. - // - // If we already had selected something in the past and it has any - // peers, we want to stay on it. If there are no peers at all, - // stay on whatever DERP we previously picked. If we need to pick - // one and have no peer info, pick a region randomly. - // - // We used to do the above for legacy clients, but never updated - // it for disco. 
- - if c.myDerp != 0 { - return c.myDerp - } - - h := fnv.New64() - fmt.Fprintf(h, "%p/%d", c, processStartUnixNano) // arbitrary - return ids[rand.New(rand.NewSource(int64(h.Sum64()))).Intn(len(ids))] -} - -// callNetInfoCallback calls the NetInfo callback (if previously +// callNetInfoCallback calls the callback (if previously // registered with SetNetInfoCallback) if ni has substantially changed // since the last state. // @@ -975,6 +694,13 @@ func (c *Conn) addValidDiscoPathForTest(nodeKey key.NodePublic, addr netip.AddrP c.peerMap.setNodeKeyForIPPort(addr, nodeKey) } +// SetNetInfoCallback sets the func to be called whenever the network conditions +// change. +// +// At most one func can be registered; the most recent one replaces any previous +// registration. +// +// This is called by LocalBackend. func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) { if fn == nil { panic("nil NetInfoCallback") @@ -1006,7 +732,7 @@ func (c *Conn) LastRecvActivityOfNodeKey(nk key.NodePublic) string { } // Ping handles a "tailscale ping" CLI query. 
-func (c *Conn) Ping(peer *tailcfg.Node, res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { +func (c *Conn) Ping(peer tailcfg.NodeView, res *ipnstate.PingResult, size int, cb func(*ipnstate.PingResult)) { c.mu.Lock() defer c.mu.Unlock() if c.privateKey.IsZero() { @@ -1014,29 +740,29 @@ func (c *Conn) Ping(peer *tailcfg.Node, res *ipnstate.PingResult, cb func(*ipnst cb(res) return } - if len(peer.Addresses) > 0 { - res.NodeIP = peer.Addresses[0].Addr().String() + if peer.Addresses().Len() > 0 { + res.NodeIP = peer.Addresses().At(0).Addr().String() } - res.NodeName = peer.Name // prefer DNS name + res.NodeName = peer.Name() // prefer DNS name if res.NodeName == "" { - res.NodeName = peer.Hostinfo.Hostname() // else hostname + res.NodeName = peer.Hostinfo().Hostname() // else hostname } else { res.NodeName, _, _ = strings.Cut(res.NodeName, ".") } - ep, ok := c.peerMap.endpointForNodeKey(peer.Key) + ep, ok := c.peerMap.endpointForNodeKey(peer.Key()) if !ok { res.Err = "unknown peer" cb(res) return } - ep.cliPing(res, cb) + ep.cliPing(res, size, cb) } // c.mu must be held func (c *Conn) populateCLIPingResponseLocked(res *ipnstate.PingResult, latency time.Duration, ep netip.AddrPort) { res.LatencySeconds = latency.Seconds() - if ep.Addr() != derpMagicIPAddr { + if ep.Addr() != tailcfg.DerpMagicIPAddr { res.Endpoint = ep.String() return } @@ -1048,13 +774,13 @@ func (c *Conn) populateCLIPingResponseLocked(res *ipnstate.PingResult, latency t // GetEndpointChanges returns the most recent changes for a particular // endpoint. The returned EndpointChange structs are for debug use only and // there are no guarantees about order, size, or content. 
-func (c *Conn) GetEndpointChanges(peer *tailcfg.Node) ([]EndpointChange, error) { +func (c *Conn) GetEndpointChanges(peer tailcfg.NodeView) ([]EndpointChange, error) { c.mu.Lock() if c.privateKey.IsZero() { c.mu.Unlock() return nil, fmt.Errorf("tailscaled stopped") } - ep, ok := c.peerMap.endpointForNodeKey(peer.Key) + ep, ok := c.peerMap.endpointForNodeKey(peer.Key()) c.mu.Unlock() if !ok { @@ -1064,79 +790,11 @@ func (c *Conn) GetEndpointChanges(peer *tailcfg.Node) ([]EndpointChange, error) return ep.debugUpdates.GetAll(), nil } -func (c *Conn) derpRegionCodeLocked(regionID int) string { - if c.derpMap == nil { - return "" - } - if dr, ok := c.derpMap.Regions[regionID]; ok { - return dr.RegionCode - } - return "" -} - // DiscoPublicKey returns the discovery public key. func (c *Conn) DiscoPublicKey() key.DiscoPublic { return c.discoPublic } -// c.mu must NOT be held. -func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { - c.mu.Lock() - defer c.mu.Unlock() - if !c.wantDerpLocked() { - c.myDerp = 0 - health.SetMagicSockDERPHome(0) - return false - } - if derpNum == c.myDerp { - // No change. - return true - } - if c.myDerp != 0 && derpNum != 0 { - metricDERPHomeChange.Add(1) - } - c.myDerp = derpNum - health.SetMagicSockDERPHome(derpNum) - - if c.privateKey.IsZero() { - // No private key yet, so DERP connections won't come up anyway. - // Return early rather than ultimately log a couple lines of noise. - return true - } - - // On change, notify all currently connected DERP servers and - // start connecting to our home DERP if we are not already. 
- dr := c.derpMap.Regions[derpNum] - if dr == nil { - c.logf("[unexpected] magicsock: derpMap.Regions[%v] is nil", derpNum) - } else { - c.logf("magicsock: home is now derp-%v (%v)", derpNum, c.derpMap.Regions[derpNum].RegionCode) - } - for i, ad := range c.activeDerp { - go ad.c.NotePreferred(i == c.myDerp) - } - c.goDerpConnect(derpNum) - return true -} - -// startDerpHomeConnectLocked starts connecting to our DERP home, if any. -// -// c.mu must be held. -func (c *Conn) startDerpHomeConnectLocked() { - c.goDerpConnect(c.myDerp) -} - -// goDerpConnect starts a goroutine to start connecting to the given -// DERP node. -// -// c.mu may be held, but does not need to be. -func (c *Conn) goDerpConnect(node int) { - if node == 0 { - return - } - go c.derpWriteChanOfAddr(netip.AddrPortFrom(derpMagicIPAddr, uint16(node)), key.NodePublic{}) -} - // determineEndpoints returns the machine's endpoint addresses. It // does a STUN lookup (via netcheck) to determine its public address. // @@ -1210,8 +868,6 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro addAddr(ipp(nr.GlobalV6), tailcfg.EndpointSTUN) } - c.ignoreSTUNPackets() - // Update our set of endpoints by adding any endpoints that we // previously found but haven't expired yet. This also updates the // cache with the set of endpoints discovered in this function. @@ -1330,6 +986,8 @@ var errDropDerpPacket = errors.New("too many DERP packets queued; dropping") var errNoUDP = errors.New("no UDP available on platform") +var errUnsupportedConnType = errors.New("unsupported connection type") + var ( // This acts as a compile-time check for our usage of ipv6.Message in // batchingUDPConn for both IPv6 and IPv4 operations. @@ -1408,7 +1066,7 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) // IPv6 address when the local machine doesn't have IPv6 support // returns (false, nil); it's not an error, but nothing was sent. 
func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) { - if addr.Addr() != derpMagicIPAddr { + if addr.Addr() != tailcfg.DerpMagicIPAddr { return c.sendUDP(addr, b) } @@ -1440,519 +1098,98 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (s } } -var ( - bufferedDerpWrites int - bufferedDerpWritesOnce sync.Once -) - -// bufferedDerpWritesBeforeDrop returns how many packets writes can be queued -// up the DERP client to write on the wire before we start dropping. -func bufferedDerpWritesBeforeDrop() int { - // For mobile devices, always return the previous minimum value of 32; - // we can do this outside the sync.Once to avoid that overhead. - if runtime.GOOS == "ios" || runtime.GOOS == "android" { - return 32 - } - - bufferedDerpWritesOnce.Do(func() { - // Some rough sizing: for the previous fixed value of 32, the - // total consumed memory can be: - // = numDerpRegions * messages/region * sizeof(message) - // - // For sake of this calculation, assume 100 DERP regions; at - // time of writing (2023-04-03), we have 24. - // - // A reasonable upper bound for the worst-case average size of - // a message is a *disco.CallMeMaybe message with 16 endpoints; - // since sizeof(netip.AddrPort) = 32, that's 512 bytes. Thus: - // = 100 * 32 * 512 - // = 1638400 (1.6MiB) - // - // On a reasonably-small node with 4GiB of memory that's - // connected to each region and handling a lot of load, 1.6MiB - // is about 0.04% of the total system memory. - // - // For sake of this calculation, then, let's double that memory - // usage to 0.08% and scale based on total system memory. - // - // For a 16GiB Linux box, this should buffer just over 256 - // messages. 
- systemMemory := sysresources.TotalMemory() - memoryUsable := float64(systemMemory) * 0.0008 - - const ( - theoreticalDERPRegions = 100 - messageMaximumSizeBytes = 512 - ) - bufferedDerpWrites = int(memoryUsable / (theoreticalDERPRegions * messageMaximumSizeBytes)) - - // Never drop below the previous minimum value. - if bufferedDerpWrites < 32 { - bufferedDerpWrites = 32 - } - }) - return bufferedDerpWrites +type receiveBatch struct { + msgs []ipv6.Message } -// derpWriteChanOfAddr returns a DERP client for fake UDP addresses that -// represent DERP servers, creating them as necessary. For real UDP -// addresses, it returns nil. -// -// If peer is non-zero, it can be used to find an active reverse -// path, without using addr. -func (c *Conn) derpWriteChanOfAddr(addr netip.AddrPort, peer key.NodePublic) chan<- derpWriteRequest { - if addr.Addr() != derpMagicIPAddr { - return nil - } - regionID := int(addr.Port()) - - if c.networkDown() { - return nil +func (c *Conn) getReceiveBatchForBuffs(buffs [][]byte) *receiveBatch { + batch := c.receiveBatchPool.Get().(*receiveBatch) + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + batch.msgs[i].OOB = batch.msgs[i].OOB[:cap(batch.msgs[i].OOB)] } + return batch +} - c.mu.Lock() - defer c.mu.Unlock() - if !c.wantDerpLocked() || c.closed { - return nil - } - if c.derpMap == nil || c.derpMap.Regions[regionID] == nil { - return nil - } - if c.privateKey.IsZero() { - c.logf("magicsock: DERP lookup of %v with no private key; ignoring", addr) - return nil +func (c *Conn) putReceiveBatch(batch *receiveBatch) { + for i := range batch.msgs { + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} } + c.receiveBatchPool.Put(batch) +} - // See if we have a connection open to that DERP node ID - // first. If so, might as well use it. (It's a little - // arbitrary whether we use this one vs. the reverse route - // below when we have both.) 
- ad, ok := c.activeDerp[regionID] - if ok { - *ad.lastWrite = time.Now() - c.setPeerLastDerpLocked(peer, regionID, regionID) - return ad.writeCh - } - - // If we don't have an open connection to the peer's home DERP - // node, see if we have an open connection to a DERP node - // where we'd heard from that peer already. For instance, - // perhaps peer's home is Frankfurt, but they dialed our home DERP - // node in SF to reach us, so we can reply to them using our - // SF connection rather than dialing Frankfurt. (Issue 150) - if !peer.IsZero() && useDerpRoute() { - if r, ok := c.derpRoute[peer]; ok { - if ad, ok := c.activeDerp[r.derpID]; ok && ad.c == r.dc { - c.setPeerLastDerpLocked(peer, r.derpID, regionID) - *ad.lastWrite = time.Now() - return ad.writeCh - } - } - } +// receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4. +func (c *Conn) receiveIPv4() conn.ReceiveFunc { + return c.mkReceiveFunc(&c.pconn4, &health.ReceiveIPv4, metricRecvDataIPv4) +} - why := "home-keep-alive" - if !peer.IsZero() { - why = peer.ShortString() - } - c.logf("magicsock: adding connection to derp-%v for %v", regionID, why) +// receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6. +func (c *Conn) receiveIPv6() conn.ReceiveFunc { + return c.mkReceiveFunc(&c.pconn6, &health.ReceiveIPv6, metricRecvDataIPv6) +} - firstDerp := false - if c.activeDerp == nil { - firstDerp = true - c.activeDerp = make(map[int]activeDerp) - c.prevDerp = make(map[int]*syncs.WaitGroupChan) - } +// mkReceiveFunc creates a ReceiveFunc reading from ruc. +// The provided healthItem and metric are updated if non-nil. +func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc { + // epCache caches an IPPort->endpoint for hot flows. + var epCache ippEndpointCache - // Note that derphttp.NewRegionClient does not dial the server - // (it doesn't block) so it is safe to do under the c.mu lock. 
- dc := derphttp.NewRegionClient(c.privateKey, c.logf, c.netMon, func() *tailcfg.DERPRegion { - // Warning: it is not legal to acquire - // magicsock.Conn.mu from this callback. - // It's run from derphttp.Client.connect (via Send, etc) - // and the lock ordering rules are that magicsock.Conn.mu - // must be acquired before derphttp.Client.mu. - // See https://github.com/tailscale/tailscale/issues/3726 - if c.connCtx.Err() != nil { - // We're closing anyway; return nil to stop dialing. - return nil + return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { + if healthItem != nil { + healthItem.Enter() + defer healthItem.Exit() } - derpMap := c.derpMapAtomic.Load() - if derpMap == nil { - return nil + if ruc == nil { + panic("nil RebindingUDPConn") } - return derpMap.Regions[regionID] - }) - - dc.SetCanAckPings(true) - dc.NotePreferred(c.myDerp == regionID) - dc.SetAddressFamilySelector(derpAddrFamSelector{c}) - dc.DNSCache = dnscache.Get() - - ctx, cancel := context.WithCancel(c.connCtx) - ch := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop()) - - ad.c = dc - ad.writeCh = ch - ad.cancel = cancel - ad.lastWrite = new(time.Time) - *ad.lastWrite = time.Now() - ad.createTime = time.Now() - c.activeDerp[regionID] = ad - metricNumDERPConns.Set(int64(len(c.activeDerp))) - c.logActiveDerpLocked() - c.setPeerLastDerpLocked(peer, regionID, regionID) - c.scheduleCleanStaleDerpLocked() - - // Build a startGate for the derp reader+writer - // goroutines, so they don't start running until any - // previous generation is closed. - startGate := syncs.ClosedChan() - if prev := c.prevDerp[regionID]; prev != nil { - startGate = prev.DoneChan() - } - // And register a WaitGroup(Chan) for this generation. 
- wg := syncs.NewWaitGroupChan() - wg.Add(2) - c.prevDerp[regionID] = wg - - if firstDerp { - startGate = c.derpStarted - go func() { - dc.Connect(ctx) - close(c.derpStarted) - c.muCond.Broadcast() - }() - } - - go c.runDerpReader(ctx, addr, dc, wg, startGate) - go c.runDerpWriter(ctx, dc, ch, wg, startGate) - go c.derpActiveFunc() - - return ad.writeCh -} -// setPeerLastDerpLocked notes that peer is now being written to via -// the provided DERP regionID, and that the peer advertises a DERP -// home region ID of homeID. -// -// If there's any change, it logs. -// -// c.mu must be held. -func (c *Conn) setPeerLastDerpLocked(peer key.NodePublic, regionID, homeID int) { - if peer.IsZero() { - return - } - old := c.peerLastDerp[peer] - if old == regionID { - return - } - c.peerLastDerp[peer] = regionID + batch := c.getReceiveBatchForBuffs(buffs) + defer c.putReceiveBatch(batch) + for { + numMsgs, err := ruc.ReadBatch(batch.msgs[:len(buffs)], 0) + if err != nil { + if neterror.PacketWasTruncated(err) { + continue + } + return 0, err + } - var newDesc string - switch { - case regionID == homeID && regionID == c.myDerp: - newDesc = "shared home" - case regionID == homeID: - newDesc = "their home" - case regionID == c.myDerp: - newDesc = "our home" - case regionID != homeID: - newDesc = "alt" - } - if old == 0 { - c.logf("[v1] magicsock: derp route for %s set to derp-%d (%s)", peer.ShortString(), regionID, newDesc) - } else { - c.logf("[v1] magicsock: derp route for %s changed from derp-%d => derp-%d (%s)", peer.ShortString(), old, regionID, newDesc) + reportToCaller := false + for i, msg := range batch.msgs[:numMsgs] { + if msg.N == 0 { + sizes[i] = 0 + continue + } + ipp := msg.Addr.(*net.UDPAddr).AddrPort() + if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok { + if metric != nil { + metric.Add(1) + } + eps[i] = ep + sizes[i] = msg.N + reportToCaller = true + } else { + sizes[i] = 0 + } + } + if reportToCaller { + return numMsgs, nil + } + } } } -// 
derpReadResult is the type sent by runDerpClient to ReceiveIPv4 -// when a DERP packet is available. +// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6. // -// Notably, it doesn't include the derp.ReceivedPacket because we -// don't want to give the receiver access to the aliased []byte. To -// get at the packet contents they need to call copyBuf to copy it -// out, which also releases the buffer. -type derpReadResult struct { - regionID int - n int // length of data received - src key.NodePublic - // copyBuf is called to copy the data to dst. It returns how - // much data was copied, which will be n if dst is large - // enough. copyBuf can only be called once. - // If copyBuf is nil, that's a signal from the sender to ignore - // this message. - copyBuf func(dst []byte) int -} - -// runDerpReader runs in a goroutine for the life of a DERP -// connection, handling received packets. -func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr netip.AddrPort, dc *derphttp.Client, wg *syncs.WaitGroupChan, startGate <-chan struct{}) { - defer wg.Decr() - defer dc.Close() - - select { - case <-startGate: - case <-ctx.Done(): - return +// ok is whether this read should be reported up to wireguard-go (our +// caller). +func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) (ep *endpoint, ok bool) { + if stun.Is(b) { + c.netChecker.ReceiveSTUNPacket(b, ipp) + return nil, false } - - didCopy := make(chan struct{}, 1) - regionID := int(derpFakeAddr.Port()) - res := derpReadResult{regionID: regionID} - var pkt derp.ReceivedPacket - res.copyBuf = func(dst []byte) int { - n := copy(dst, pkt.Data) - didCopy <- struct{}{} - return n - } - - defer health.SetDERPRegionConnectedState(regionID, false) - defer health.SetDERPRegionHealth(regionID, "") - - // peerPresent is the set of senders we know are present on this - // connection, based on messages we've received from the server. 
- peerPresent := map[key.NodePublic]bool{} - bo := backoff.NewBackoff(fmt.Sprintf("derp-%d", regionID), c.logf, 5*time.Second) - var lastPacketTime time.Time - var lastPacketSrc key.NodePublic - - for { - msg, connGen, err := dc.RecvDetail() - if err != nil { - health.SetDERPRegionConnectedState(regionID, false) - // Forget that all these peers have routes. - for peer := range peerPresent { - delete(peerPresent, peer) - c.removeDerpPeerRoute(peer, regionID, dc) - } - if err == derphttp.ErrClientClosed { - return - } - if c.networkDown() { - c.logf("[v1] magicsock: derp.Recv(derp-%d): network down, closing", regionID) - return - } - select { - case <-ctx.Done(): - return - default: - } - - c.logf("magicsock: [%p] derp.Recv(derp-%d): %v", dc, regionID, err) - - // If our DERP connection broke, it might be because our network - // conditions changed. Start that check. - c.ReSTUN("derp-recv-error") - - // Back off a bit before reconnecting. - bo.BackOff(ctx, err) - select { - case <-ctx.Done(): - return - default: - } - continue - } - bo.BackOff(ctx, nil) // reset - - now := time.Now() - if lastPacketTime.IsZero() || now.Sub(lastPacketTime) > 5*time.Second { - health.NoteDERPRegionReceivedFrame(regionID) - lastPacketTime = now - } - - switch m := msg.(type) { - case derp.ServerInfoMessage: - health.SetDERPRegionConnectedState(regionID, true) - health.SetDERPRegionHealth(regionID, "") // until declared otherwise - c.logf("magicsock: derp-%d connected; connGen=%v", regionID, connGen) - continue - case derp.ReceivedPacket: - pkt = m - res.n = len(m.Data) - res.src = m.Source - if logDerpVerbose() { - c.logf("magicsock: got derp-%v packet: %q", regionID, m.Data) - } - // If this is a new sender we hadn't seen before, remember it and - // register a route for this peer. 
- if res.src != lastPacketSrc { // avoid map lookup w/ high throughput single peer - lastPacketSrc = res.src - if _, ok := peerPresent[res.src]; !ok { - peerPresent[res.src] = true - c.addDerpPeerRoute(res.src, regionID, dc) - } - } - case derp.PingMessage: - // Best effort reply to the ping. - pingData := [8]byte(m) - go func() { - if err := dc.SendPong(pingData); err != nil { - c.logf("magicsock: derp-%d SendPong error: %v", regionID, err) - } - }() - continue - case derp.HealthMessage: - health.SetDERPRegionHealth(regionID, m.Problem) - case derp.PeerGoneMessage: - switch m.Reason { - case derp.PeerGoneReasonDisconnected: - // Do nothing. - case derp.PeerGoneReasonNotHere: - metricRecvDiscoDERPPeerNotHere.Add(1) - c.logf("[unexpected] magicsock: derp-%d does not know about peer %s, removing route", - regionID, key.NodePublic(m.Peer).ShortString()) - default: - metricRecvDiscoDERPPeerGoneUnknown.Add(1) - c.logf("[unexpected] magicsock: derp-%d peer %s gone, reason %v, removing route", - regionID, key.NodePublic(m.Peer).ShortString(), m.Reason) - } - c.removeDerpPeerRoute(key.NodePublic(m.Peer), regionID, dc) - default: - // Ignore. - continue - } - - select { - case <-ctx.Done(): - return - case c.derpRecvCh <- res: - } - - select { - case <-ctx.Done(): - return - case <-didCopy: - continue - } - } -} - -type derpWriteRequest struct { - addr netip.AddrPort - pubKey key.NodePublic - b []byte // copied; ownership passed to receiver -} - -// runDerpWriter runs in a goroutine for the life of a DERP -// connection, handling received packets. 
-func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan derpWriteRequest, wg *syncs.WaitGroupChan, startGate <-chan struct{}) { - defer wg.Decr() - select { - case <-startGate: - case <-ctx.Done(): - return - } - - for { - select { - case <-ctx.Done(): - return - case wr := <-ch: - err := dc.Send(wr.pubKey, wr.b) - if err != nil { - c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) - metricSendDERPError.Add(1) - } else { - metricSendDERP.Add(1) - } - } - } -} - -type receiveBatch struct { - msgs []ipv6.Message -} - -func (c *Conn) getReceiveBatchForBuffs(buffs [][]byte) *receiveBatch { - batch := c.receiveBatchPool.Get().(*receiveBatch) - for i := range buffs { - batch.msgs[i].Buffers[0] = buffs[i] - batch.msgs[i].OOB = batch.msgs[i].OOB[:cap(batch.msgs[i].OOB)] - } - return batch -} - -func (c *Conn) putReceiveBatch(batch *receiveBatch) { - for i := range batch.msgs { - batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} - } - c.receiveBatchPool.Put(batch) -} - -// receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4. -func (c *Conn) receiveIPv4() conn.ReceiveFunc { - return c.mkReceiveFunc(&c.pconn4, &health.ReceiveIPv4, metricRecvDataIPv4) -} - -// receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6. -func (c *Conn) receiveIPv6() conn.ReceiveFunc { - return c.mkReceiveFunc(&c.pconn6, &health.ReceiveIPv6, metricRecvDataIPv6) -} - -// mkReceiveFunc creates a ReceiveFunc reading from ruc. -// The provided healthItem and metric are updated if non-nil. -func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc { - // epCache caches an IPPort->endpoint for hot flows. 
- var epCache ippEndpointCache - - return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { - if healthItem != nil { - healthItem.Enter() - defer healthItem.Exit() - } - if ruc == nil { - panic("nil RebindingUDPConn") - } - - batch := c.getReceiveBatchForBuffs(buffs) - defer c.putReceiveBatch(batch) - for { - numMsgs, err := ruc.ReadBatch(batch.msgs[:len(buffs)], 0) - if err != nil { - if neterror.PacketWasTruncated(err) { - continue - } - return 0, err - } - - reportToCaller := false - for i, msg := range batch.msgs[:numMsgs] { - if msg.N == 0 { - sizes[i] = 0 - continue - } - ipp := msg.Addr.(*net.UDPAddr).AddrPort() - if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok { - if metric != nil { - metric.Add(1) - } - eps[i] = ep - sizes[i] = msg.N - reportToCaller = true - } else { - sizes[i] = 0 - } - } - if reportToCaller { - return numMsgs, nil - } - } - } -} - -// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6. -// -// ok is whether this read should be reported up to wireguard-go (our -// caller). 
-func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) (ep *endpoint, ok bool) { - if stun.Is(b) { - c.stunReceiveFunc.Load()(b, ipp) - return nil, false - } - if c.handleDiscoMessage(b, ipp, key.NodePublic{}, discoRXPathUDP) { - return nil, false + if c.handleDiscoMessage(b, ipp, key.NodePublic{}, discoRXPathUDP) { + return nil, false } if !c.havePrivateKey.Load() { // If we have no private key, we're logged out or @@ -1974,69 +1211,13 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) cache.gen = de.numStopAndReset() ep = de } - ep.noteRecvActivity() + ep.noteRecvActivity(ipp) if stats := c.stats.Load(); stats != nil { stats.UpdateRxPhysical(ep.nodeAddr, ipp, len(b)) } return ep, true } -func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { - health.ReceiveDERP.Enter() - defer health.ReceiveDERP.Exit() - - for dm := range c.derpRecvCh { - if c.isClosed() { - break - } - n, ep := c.processDERPReadResult(dm, buffs[0]) - if n == 0 { - // No data read occurred. Wait for another packet. - continue - } - metricRecvDataDERP.Add(1) - sizes[0] = n - eps[0] = ep - return 1, nil - } - return 0, net.ErrClosed -} - -func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *endpoint) { - if dm.copyBuf == nil { - return 0, nil - } - var regionID int - n, regionID = dm.n, dm.regionID - ncopy := dm.copyBuf(b) - if ncopy != n { - err := fmt.Errorf("received DERP packet of length %d that's too big for WireGuard buf size %d", n, ncopy) - c.logf("magicsock: %v", err) - return 0, nil - } - - ipp := netip.AddrPortFrom(derpMagicIPAddr, uint16(regionID)) - if c.handleDiscoMessage(b[:n], ipp, dm.src, discoRXPathDERP) { - return 0, nil - } - - var ok bool - c.mu.Lock() - ep, ok = c.peerMap.endpointForNodeKey(dm.src) - c.mu.Unlock() - if !ok { - // We don't know anything about this node key, nothing to - // record or process. 
- return 0, nil - } - - ep.noteRecvActivity() - if stats := c.stats.Load(); stats != nil { - stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n) - } - return n, ep -} - // discoLogLevel controls the verbosity of discovery log messages. type discoLogLevel int @@ -2062,7 +1243,7 @@ var debugIPv4DiscoPingPenalty = envknob.RegisterDuration("TS_DISCO_PONG_IPV4_DEL // The dstKey should only be non-zero if the dstDisco key // unambiguously maps to exactly one peer. func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { - isDERP := dst.Addr() == derpMagicIPAddr + isDERP := dst.Addr() == tailcfg.DerpMagicIPAddr if _, isPong := m.(*disco.Pong); isPong && !isDERP && dst.Addr().Is4() { time.Sleep(debugIPv4DiscoPingPenalty()) } @@ -2072,10 +1253,6 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi c.mu.Unlock() return false, errConnClosed } - var nonce [disco.NonceLen]byte - if _, err := crand.Read(nonce[:]); err != nil { - panic(err) // worth dying for - } pkt := make([]byte, 0, 512) // TODO: size it correctly? pool? if it matters. pkt = append(pkt, disco.Magic...) pkt = c.discoPublic.AppendTo(pkt) @@ -2097,7 +1274,7 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi if !dstKey.IsZero() { node = dstKey.ShortString() } - c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m)) + c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt)) } if isDERP { metricSentDiscoDERP.Add(1) @@ -2116,40 +1293,12 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi // Can't send. (e.g. 
no IPv6 locally) } else { if !c.networkDown() { - c.logf("magicsock: disco: failed to send %T to %v: %v", m, dst, err) + c.logf("magicsock: disco: failed to send %v to %v: %v", disco.MessageSummary(m), dst, err) } } return sent, err } -// discoPcapFrame marshals the bytes for a pcap record that describe a -// disco frame. -// -// Warning: Alloc garbage. Acceptable while capturing. -func discoPcapFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []byte) []byte { - var ( - b bytes.Buffer - flag uint8 - ) - b.Grow(128) // Most disco frames will probably be smaller than this. - - if src.Addr() == derpMagicIPAddr { - flag |= 0x01 - } - b.WriteByte(flag) // 1b: flag - - derpSrc := derpNodeSrc.Raw32() - b.Write(derpSrc[:]) // 32b: derp public key - binary.Write(&b, binary.LittleEndian, uint16(src.Port())) // 2b: port - addr, _ := src.Addr().MarshalBinary() - binary.Write(&b, binary.LittleEndian, uint16(len(addr))) // 2b: len(addr) - b.Write(addr) // Xb: addr - binary.Write(&b, binary.LittleEndian, uint16(len(payload))) // 2b: len(payload) - b.Write(payload) // Xb: payload - - return b.Bytes() -} - type discoRXPath string const ( @@ -2193,7 +1342,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } if debugDisco() { - c.logf("magicsock: disco: got disco-looking frame from %v via %s", sender.ShortString(), via) + c.logf("magicsock: disco: got disco-looking frame from %v via %s len %v", sender.ShortString(), via, len(msg)) } if c.privateKey.IsZero() { // Ignore disco messages when we're stopped. @@ -2239,7 +1388,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke // Emit information about the disco frame into the pcap stream // if a capture hook is installed. 
if cb := c.captureHook.Load(); cb != nil { - cb(capture.PathDisco, time.Now(), discoPcapFrame(src, derpNodeSrc, payload), packet.CaptureMeta{}) + cb(capture.PathDisco, time.Now(), disco.ToPCAPFrame(src, derpNodeSrc, payload), packet.CaptureMeta{}) } dm, err := disco.Parse(payload) @@ -2256,7 +1405,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } - isDERP := src.Addr() == derpMagicIPAddr + isDERP := src.Addr() == tailcfg.DerpMagicIPAddr if isDERP { metricRecvDiscoDERP.Add(1) } else { @@ -2354,7 +1503,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf likelyHeartBeat := src == di.lastPingFrom && time.Since(di.lastPingTime) < 5*time.Second di.lastPingFrom = src di.lastPingTime = time.Now() - isDerp := src.Addr() == derpMagicIPAddr + isDerp := src.Addr() == tailcfg.DerpMagicIPAddr // If we can figure out with certainty which node key this disco // message is for, eagerly update our IP<>node and disco<>node @@ -2519,7 +1668,7 @@ func (c *Conn) SetPreferredPort(port uint16) { c.port.Store(uint32(port)) if err := c.rebind(dropCurrentPort); err != nil { - c.logf("%w", err) + c.logf("%v", err) return } c.resetEndpointStates() @@ -2582,7 +1731,7 @@ func (c *Conn) SetPrivateKey(privateKey key.NodePrivate) error { // then removes any state for old peers. // // The caller passes ownership of newPeers map to UpdatePeers. -func (c *Conn) UpdatePeers(newPeers map[key.NodePublic]struct{}) { +func (c *Conn) UpdatePeers(newPeers set.Set[key.NodePublic]) { c.mu.Lock() defer c.mu.Unlock() @@ -2592,7 +1741,7 @@ func (c *Conn) UpdatePeers(newPeers map[key.NodePublic]struct{}) { // Clean up any key.NodePublic-keyed maps for peers that no longer // exist. 
for peer := range oldPeers { - if _, ok := newPeers[peer]; !ok { + if !newPeers.Contains(peer) { delete(c.derpRoute, peer) delete(c.peerLastDerp, peer) } @@ -2603,81 +1752,53 @@ func (c *Conn) UpdatePeers(newPeers map[key.NodePublic]struct{}) { } } -// SetDERPMap controls which (if any) DERP servers are used. -// A nil value means to disable DERP; it's disabled by default. -func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) { - c.mu.Lock() - defer c.mu.Unlock() - - var derpAddr = debugUseDERPAddr() - if derpAddr != "" { - derpPort := 443 - if debugUseDERPHTTP() { - // Match the port for -dev in derper.go - derpPort = 3340 - } - dm = &tailcfg.DERPMap{ - OmitDefaultRegions: true, - Regions: map[int]*tailcfg.DERPRegion{ - 999: { - RegionID: 999, - Nodes: []*tailcfg.DERPNode{{ - Name: "999dev", - RegionID: 999, - HostName: derpAddr, - DERPPort: derpPort, - }}, - }, - }, +func nodesEqual(x, y views.Slice[tailcfg.NodeView]) bool { + if x.Len() != y.Len() { + return false + } + for i := range x.LenIter() { + if !x.At(i).Equal(y.At(i)) { + return false } } + return true +} - if reflect.DeepEqual(dm, c.derpMap) { - return +// debugRingBufferSize returns a maximum size for our set of endpoint ring +// buffers by assuming that a single large update is ~500 bytes, and that we +// want to not use more than 1MiB of memory on phones / 4MiB on other devices. +// Calculate the per-endpoint ring buffer size by dividing that out, but always +// storing at least two entries. +func debugRingBufferSize(numPeers int) int { + const defaultVal = 2 + if numPeers == 0 { + return defaultVal } - - c.derpMapAtomic.Store(dm) - old := c.derpMap - c.derpMap = dm - if dm == nil { - c.closeAllDerpLocked("derp-disabled") - return + var maxRingBufferSize int + if runtime.GOOS == "ios" || runtime.GOOS == "android" { + maxRingBufferSize = 1 * 1024 * 1024 + } else { + maxRingBufferSize = 4 * 1024 * 1024 } - - // Reconnect any DERP region that changed definitions. 
- if old != nil { - changes := false - for rid, oldDef := range old.Regions { - if reflect.DeepEqual(oldDef, dm.Regions[rid]) { - continue - } - changes = true - if rid == c.myDerp { - c.myDerp = 0 - } - c.closeDerpLocked(rid, "derp-region-redefined") - } - if changes { - c.logActiveDerpLocked() - } + if v := debugRingBufferMaxSizeBytes(); v > 0 { + maxRingBufferSize = v } - go c.ReSTUN("derp-map-update") + const averageRingBufferElemSize = 512 + return max(defaultVal, maxRingBufferSize/(averageRingBufferElemSize*numPeers)) } -func nodesEqual(x, y []*tailcfg.Node) bool { - if len(x) != len(y) { - return false - } - for i := range x { - if !x[i].Equal(y[i]) { - return false - } - } - return true +// debugFlags are the debug flags in use by the magicsock package. +// They might be set by envknob and/or controlknob. +// The value is comparable. +type debugFlags struct { + heartbeatDisabled bool } -var debugRingBufferMaxSizeBytes = envknob.RegisterInt("TS_DEBUG_MAGICSOCK_RING_BUFFER_MAX_SIZE_BYTES") +func (c *Conn) debugFlagsLocked() (f debugFlags) { + f.heartbeatDisabled = debugEnableSilentDisco() // TODO(bradfitz): controlknobs too, later + return +} // SetNetworkMap is called when the control client gets a new network // map from the control server. It must always be non-nil. @@ -2692,50 +1813,32 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { return } - priorNetmap := c.netMap - var priorDebug *tailcfg.Debug - if priorNetmap != nil { - priorDebug = priorNetmap.Debug - } - debugChanged := !reflect.DeepEqual(priorDebug, nm.Debug) + priorPeers := c.peers metricNumPeers.Set(int64(len(nm.Peers))) // Update c.netMap regardless, before the following early return. 
- c.netMap = nm + curPeers := views.SliceOf(nm.Peers) + c.peers = curPeers - if priorNetmap != nil && nodesEqual(priorNetmap.Peers, nm.Peers) && !debugChanged { + flags := c.debugFlagsLocked() + if addrs := nm.GetAddresses(); addrs.Len() > 0 { + c.firstAddrForTest = addrs.At(0).Addr() + } else { + c.firstAddrForTest = netip.Addr{} + } + + if nodesEqual(priorPeers, curPeers) && c.lastFlags == flags { // The rest of this function is all adjusting state for peers that have // changed. But if the set of peers is equal and the debug flags (for // silent disco) haven't changed, no need to do anything else. return } + c.lastFlags = flags + c.logf("[v1] magicsock: got updated network map; %d peers", len(nm.Peers)) - heartbeatDisabled := debugEnableSilentDisco() || (c.netMap != nil && c.netMap.Debug != nil && c.netMap.Debug.EnableSilentDisco) - - // Set a maximum size for our set of endpoint ring buffers by assuming - // that a single large update is ~500 bytes, and that we want to not - // use more than 1MiB of memory on phones / 4MiB on other devices. - // Calculate the per-endpoint ring buffer size by dividing that out, - // but always storing at least two entries. - var entriesPerBuffer int = 2 - if len(nm.Peers) > 0 { - var maxRingBufferSize int - if runtime.GOOS == "ios" || runtime.GOOS == "android" { - maxRingBufferSize = 1 * 1024 * 1024 - } else { - maxRingBufferSize = 4 * 1024 * 1024 - } - if v := debugRingBufferMaxSizeBytes(); v > 0 { - maxRingBufferSize = v - } - const averageRingBufferElemSize = 512 - entriesPerBuffer = maxRingBufferSize / (averageRingBufferElemSize * len(nm.Peers)) - if entriesPerBuffer < 2 { - entriesPerBuffer = 2 - } - } + entriesPerBuffer := debugRingBufferSize(len(nm.Peers)) // Try a pass of just upserting nodes and creating missing // endpoints. 
If the set of nodes is the same, this is an @@ -2743,8 +1846,27 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { // we'll fall through to the next pass, which allocates but can // handle full set updates. for _, n := range nm.Peers { - if ep, ok := c.peerMap.endpointForNodeKey(n.Key); ok { - if n.DiscoKey.IsZero() && !n.IsWireGuardOnly { + if n.ID() == 0 { + devPanicf("node with zero ID") + continue + } + if n.Key().IsZero() { + devPanicf("node with zero key") + continue + } + ep, ok := c.peerMap.endpointForNodeID(n.ID()) + if ok && ep.publicKey != n.Key() { + // The node rotated public keys. Delete the old endpoint and create + // it anew. + c.peerMap.deleteEndpoint(ep) + ok = false + } + if ok { + // At this point we're modifying an existing endpoint (ep) whose + // public key and nodeID match n. Its other fields (such as disco + // key or endpoints) might've changed. + + if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { // Discokey transitioned from non-zero to zero? This should not // happen in the wild, however it could mean: // 1. A node was downgraded from post 0.100 to pre 0.100. @@ -2759,64 +1881,60 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { if epDisco := ep.disco.Load(); epDisco != nil { oldDiscoKey = epDisco.key } - ep.updateFromNode(n, heartbeatDisabled) + ep.updateFromNode(n, flags.heartbeatDisabled) c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap continue } - if n.DiscoKey.IsZero() && !n.IsWireGuardOnly { - // Ancient pre-0.100 node, which does not have a disco key, and will only be reachable via DERP. + + if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok { + // At this point n.Key() should be for a key we've never seen before. If + // ok was true above, it was an update to an existing matching key and + // we don't get this far. If ok was false above, that means it's a key + // that differs from the one the NodeID had. But double check. + if ep.nodeID != n.ID() { + // Server error. 
+ devPanicf("public key moved between nodeIDs") + } else { + // Internal data structures out of sync. + devPanicf("public key found in peerMap but not by nodeID") + } + continue + } + if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { + // Ancient pre-0.100 node, which does not have a disco key. + // No longer supported. continue } - ep := &endpoint{ + ep = &endpoint{ c: c, debugUpdates: ringbuffer.New[EndpointChange](entriesPerBuffer), - publicKey: n.Key, - publicKeyHex: n.Key.UntypedHexString(), + nodeID: n.ID(), + publicKey: n.Key(), + publicKeyHex: n.Key().UntypedHexString(), sentPing: map[stun.TxID]sentPing{}, endpointState: map[netip.AddrPort]*endpointState{}, - heartbeatDisabled: heartbeatDisabled, - isWireguardOnly: n.IsWireGuardOnly, + heartbeatDisabled: flags.heartbeatDisabled, + isWireguardOnly: n.IsWireGuardOnly(), } - if len(n.Addresses) > 0 { - ep.nodeAddr = n.Addresses[0].Addr() + if n.Addresses().Len() > 0 { + ep.nodeAddr = n.Addresses().At(0).Addr() } ep.initFakeUDPAddr() - if n.DiscoKey.IsZero() { + if n.DiscoKey().IsZero() { ep.disco.Store(nil) } else { ep.disco.Store(&endpointDisco{ - key: n.DiscoKey, - short: n.DiscoKey.ShortString(), + key: n.DiscoKey(), + short: n.DiscoKey().ShortString(), }) + } - if debugDisco() { // rather than making a new knob - c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key.ShortString(), n.DiscoKey.ShortString(), logger.ArgWriter(func(w *bufio.Writer) { - const derpPrefix = "127.3.3.40:" - if strings.HasPrefix(n.DERP, derpPrefix) { - ipp, _ := netip.ParseAddrPort(n.DERP) - regionID := int(ipp.Port()) - code := c.derpRegionCodeLocked(regionID) - if code != "" { - code = "(" + code + ")" - } - fmt.Fprintf(w, "derp=%v%s ", regionID, code) - } - - for _, a := range n.AllowedIPs { - if a.IsSingleIP() { - fmt.Fprintf(w, "aip=%v ", a.Addr()) - } else { - fmt.Fprintf(w, "aip=%v ", a) - } - } - for _, ep := range n.Endpoints { - fmt.Fprintf(w, "ep=%v ", ep) - } - })) - } + if debugPeerMap() { + 
c.logEndpointCreated(n) } - ep.updateFromNode(n, heartbeatDisabled) + + ep.updateFromNode(n, flags.heartbeatDisabled) c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) } @@ -2826,12 +1944,12 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { // current netmap. If that happens, go through the allocful // deletion path to clean up moribund nodes. if c.peerMap.nodeCount() != len(nm.Peers) { - keep := make(map[key.NodePublic]bool, len(nm.Peers)) + keep := set.Set[key.NodePublic]{} for _, n := range nm.Peers { - keep[n.Key] = true + keep.Add(n.Key()) } c.peerMap.forEachEndpoint(func(ep *endpoint) { - if !keep[ep.publicKey] { + if !keep.Contains(ep.publicKey) { c.peerMap.deleteEndpoint(ep) } }) @@ -2845,100 +1963,38 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { } } -func (c *Conn) wantDerpLocked() bool { return c.derpMap != nil } - -// c.mu must be held. -func (c *Conn) closeAllDerpLocked(why string) { - if len(c.activeDerp) == 0 { - return // without the useless log statement - } - for i := range c.activeDerp { - c.closeDerpLocked(i, why) +func devPanicf(format string, a ...any) { + if testenv.InTest() || envknob.CrashOnUnexpected() { + panic(fmt.Sprintf(format, a...)) } - c.logActiveDerpLocked() } -// maybeCloseDERPsOnRebind, in response to a rebind, closes all -// DERP connections that don't have a local address in okayLocalIPs -// and pings all those that do. 
-func (c *Conn) maybeCloseDERPsOnRebind(okayLocalIPs []netip.Prefix) { - c.mu.Lock() - defer c.mu.Unlock() - for regionID, ad := range c.activeDerp { - la, err := ad.c.LocalAddr() - if err != nil { - c.closeOrReconnectDERPLocked(regionID, "rebind-no-localaddr") - continue - } - if !tsaddr.PrefixesContainsIP(okayLocalIPs, la.Addr()) { - c.closeOrReconnectDERPLocked(regionID, "rebind-default-route-change") - continue +func (c *Conn) logEndpointCreated(n tailcfg.NodeView) { + c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key().ShortString(), n.DiscoKey().ShortString(), logger.ArgWriter(func(w *bufio.Writer) { + const derpPrefix = "127.3.3.40:" + if strings.HasPrefix(n.DERP(), derpPrefix) { + ipp, _ := netip.ParseAddrPort(n.DERP()) + regionID := int(ipp.Port()) + code := c.derpRegionCodeLocked(regionID) + if code != "" { + code = "(" + code + ")" + } + fmt.Fprintf(w, "derp=%v%s ", regionID, code) } - regionID := regionID - dc := ad.c - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - if err := dc.Ping(ctx); err != nil { - c.mu.Lock() - defer c.mu.Unlock() - c.closeOrReconnectDERPLocked(regionID, "rebind-ping-fail") - return + + for i := range n.AllowedIPs().LenIter() { + a := n.AllowedIPs().At(i) + if a.IsSingleIP() { + fmt.Fprintf(w, "aip=%v ", a.Addr()) + } else { + fmt.Fprintf(w, "aip=%v ", a) } - c.logf("post-rebind ping of DERP region %d okay", regionID) - }() - } - c.logActiveDerpLocked() -} - -// closeOrReconnectDERPLocked closes the DERP connection to the -// provided regionID and starts reconnecting it if it's our current -// home DERP. -// -// why is a reason for logging. -// -// c.mu must be held. -func (c *Conn) closeOrReconnectDERPLocked(regionID int, why string) { - c.closeDerpLocked(regionID, why) - if !c.privateKey.IsZero() && c.myDerp == regionID { - c.startDerpHomeConnectLocked() - } -} - -// c.mu must be held. 
-// It is the responsibility of the caller to call logActiveDerpLocked after any set of closes. -func (c *Conn) closeDerpLocked(regionID int, why string) { - if ad, ok := c.activeDerp[regionID]; ok { - c.logf("magicsock: closing connection to derp-%v (%v), age %v", regionID, why, time.Since(ad.createTime).Round(time.Second)) - go ad.c.Close() - ad.cancel() - delete(c.activeDerp, regionID) - metricNumDERPConns.Set(int64(len(c.activeDerp))) - } -} - -// c.mu must be held. -func (c *Conn) logActiveDerpLocked() { - now := time.Now() - c.logf("magicsock: %v active derp conns%s", len(c.activeDerp), logger.ArgWriter(func(buf *bufio.Writer) { - if len(c.activeDerp) == 0 { - return - } - buf.WriteString(":") - c.foreachActiveDerpSortedLocked(func(node int, ad activeDerp) { - fmt.Fprintf(buf, " derp-%d=cr%v,wr%v", node, simpleDur(now.Sub(ad.createTime)), simpleDur(now.Sub(*ad.lastWrite))) - }) - })) -} - -// EndpointChange is a structure containing information about changes made to a -// particular endpoint. This is not a stable interface and could change at any -// time. -type EndpointChange struct { - When time.Time // when the change occurred - What string // what this change is - From any `json:",omitempty"` // information about the previous state - To any `json:",omitempty"` // information about the new state + } + for i := range n.Endpoints().LenIter() { + ep := n.Endpoints().At(i) + fmt.Fprintf(w, "ep=%v ", ep) + } + })) } func (c *Conn) logEndpointChange(endpoints []tailcfg.Endpoint) { @@ -2952,77 +2008,6 @@ func (c *Conn) logEndpointChange(endpoints []tailcfg.Endpoint) { })) } -// c.mu must be held. 
-func (c *Conn) foreachActiveDerpSortedLocked(fn func(regionID int, ad activeDerp)) { - if len(c.activeDerp) < 2 { - for id, ad := range c.activeDerp { - fn(id, ad) - } - return - } - ids := make([]int, 0, len(c.activeDerp)) - for id := range c.activeDerp { - ids = append(ids, id) - } - sort.Ints(ids) - for _, id := range ids { - fn(id, c.activeDerp[id]) - } -} - -func (c *Conn) cleanStaleDerp() { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return - } - c.derpCleanupTimerArmed = false - - tooOld := time.Now().Add(-derpInactiveCleanupTime) - dirty := false - someNonHomeOpen := false - for i, ad := range c.activeDerp { - if i == c.myDerp { - continue - } - if ad.lastWrite.Before(tooOld) { - c.closeDerpLocked(i, "idle") - dirty = true - } else { - someNonHomeOpen = true - } - } - if dirty { - c.logActiveDerpLocked() - } - if someNonHomeOpen { - c.scheduleCleanStaleDerpLocked() - } -} - -func (c *Conn) scheduleCleanStaleDerpLocked() { - if c.derpCleanupTimerArmed { - // Already going to fire soon. Let the existing one - // fire lest it get infinitely delayed by repeated - // calls to scheduleCleanStaleDerpLocked. - return - } - c.derpCleanupTimerArmed = true - if c.derpCleanupTimer != nil { - c.derpCleanupTimer.Reset(derpCleanStaleInterval) - } else { - c.derpCleanupTimer = time.AfterFunc(derpCleanStaleInterval, c.cleanStaleDerp) - } -} - -// DERPs reports the number of active DERP connections. -func (c *Conn) DERPs() int { - c.mu.Lock() - defer c.mu.Unlock() - - return len(c.activeDerp) -} - // Bind returns the wireguard-go conn.Bind for c. // // See https://pkg.go.dev/golang.zx2c4.com/wireguard/conn#Bind @@ -3041,6 +2026,8 @@ type connBind struct { closed bool } +// This is a compile-time assertion that connBind implements the wireguard-go +// conn.Bind interface. 
var _ conn.Bind = (*connBind)(nil) // BatchSize returns the number of buffers expected to be passed to @@ -3184,13 +2171,6 @@ func (c *Conn) goroutinesRunningLocked() bool { return false } -func maxIdleBeforeSTUNShutdown() time.Duration { - if debugReSTUNStopOnIdle() { - return 45 * time.Second - } - return sessionActiveTimeout -} - func (c *Conn) shouldDoPeriodicReSTUNLocked() bool { if c.networkDown() { return false @@ -3205,8 +2185,8 @@ func (c *Conn) shouldDoPeriodicReSTUNLocked() bool { if debugReSTUNStopOnIdle() { c.logf("magicsock: periodicReSTUN: idle for %v", idleFor.Round(time.Second)) } - if idleFor > maxIdleBeforeSTUNShutdown() { - if c.netMap != nil && c.netMap.Debug != nil && c.netMap.Debug.ForceBackgroundSTUN { + if idleFor > sessionActiveTimeout { + if c.controlKnobs != nil && c.controlKnobs.ForceBackgroundSTUN.Load() { // Overridden by control. return true } @@ -3271,8 +2251,6 @@ func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, er return nettype.MakePacketListenerWithNetIP(netns.Listener(c.logf, c.netMon)).ListenPacket(ctx, network, addr) } -var debugBindSocket = envknob.RegisterBool("TS_DEBUG_MAGICSOCK_BIND_SOCKET") - // bindSocket initializes rucPtr if necessary and binds a UDP socket to it. // Network indicates the UDP socket type; it must be "udp4" or "udp6". // If rucPtr had an existing UDP socket bound, it closes that socket. @@ -3334,6 +2312,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur continue } trySetSocketBuffer(pconn, c.logf) + // Success. 
if debugBindSocket() { c.logf("magicsock: bindSocket: successfully listened %v port %d", network, port) @@ -3373,6 +2352,7 @@ func (c *Conn) rebind(curPortFate currentPortFate) error { return fmt.Errorf("magicsock: Rebind IPv4 failed: %w", err) } c.portMapper.SetLocalPort(c.LocalPort()) + c.UpdatePMTUD() return nil } @@ -3381,7 +2361,7 @@ func (c *Conn) rebind(curPortFate currentPortFate) error { func (c *Conn) Rebind() { metricRebindCalls.Add(1) if err := c.rebind(keepCurrentPort); err != nil { - c.logf("%w", err) + c.logf("%v", err) return } @@ -3448,183 +2428,6 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) { return ep, nil } -// xnetBatchReaderWriter defines the batching i/o methods of -// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn). -// TODO(jwhited): This should eventually be replaced with the standard library -// implementation of https://github.com/golang/go/issues/45886 -type xnetBatchReaderWriter interface { - xnetBatchReader - xnetBatchWriter -} - -type xnetBatchReader interface { - ReadBatch([]ipv6.Message, int) (int, error) -} - -type xnetBatchWriter interface { - WriteBatch([]ipv6.Message, int) (int, error) -} - -// batchingUDPConn is a UDP socket that provides batched i/o. -type batchingUDPConn struct { - pc nettype.PacketConn - xpc xnetBatchReaderWriter - rxOffload bool // supports UDP GRO or similar - txOffload atomic.Bool // supports UDP GSO or similar - setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing - getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing - sendBatchPool sync.Pool -} - -func (c *batchingUDPConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - if c.rxOffload { - // UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may - // receive a "monster datagram" from any read call. 
The ReadFrom() API - // does not support passing the GSO size and is unsafe to use in such a - // case. Other platforms may vary in behavior, but we go with the most - // conservative approach to prevent this from becoming a footgun in the - // future. - return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable") - } - return c.pc.ReadFromUDPAddrPort(p) -} - -func (c *batchingUDPConn) SetDeadline(t time.Time) error { - return c.pc.SetDeadline(t) -} - -func (c *batchingUDPConn) SetReadDeadline(t time.Time) error { - return c.pc.SetReadDeadline(t) -} - -func (c *batchingUDPConn) SetWriteDeadline(t time.Time) error { - return c.pc.SetWriteDeadline(t) -} - -const ( - // This was initially established for Linux, but may split out to - // GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the - // kernel's TX path, and UDP_GRO_CNT_MAX for RX. - udpSegmentMaxDatagrams = 64 -) - -const ( - // Exceeding these values results in EMSGSIZE. - maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 - maxIPv6PayloadLen = 1<<16 - 1 - 8 -) - -// coalesceMessages iterates msgs, coalescing them where possible while -// maintaining datagram order. All msgs have their Addr field set to addr. 
-func (c *batchingUDPConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int { - var ( - base = -1 // index of msg we are currently coalescing into - gsoSize int // segmentation size of msgs[base] - dgramCnt int // number of dgrams coalesced into msgs[base] - endBatch bool // tracking flag to start a new batch on next iteration of buffs - ) - maxPayloadLen := maxIPv4PayloadLen - if addr.IP.To4() == nil { - maxPayloadLen = maxIPv6PayloadLen - } - for i, buff := range buffs { - if i > 0 { - msgLen := len(buff) - baseLenBefore := len(msgs[base].Buffers[0]) - freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore - if msgLen+baseLenBefore <= maxPayloadLen && - msgLen <= gsoSize && - msgLen <= freeBaseCap && - dgramCnt < udpSegmentMaxDatagrams && - !endBatch { - msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...) - copy(msgs[base].Buffers[0][baseLenBefore:], buff) - if i == len(buffs)-1 { - c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) - } - dgramCnt++ - if msgLen < gsoSize { - // A smaller than gsoSize packet on the tail is legal, but - // it must end the batch. - endBatch = true - } - continue - } - } - if dgramCnt > 1 { - c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) - } - // Reset prior to incrementing base since we are preparing to start a - // new potential batch. 
- endBatch = false - base++ - gsoSize = len(buff) - msgs[base].OOB = msgs[base].OOB[:0] - msgs[base].Buffers[0] = buff - msgs[base].Addr = addr - dgramCnt = 1 - } - return base + 1 -} - -type sendBatch struct { - msgs []ipv6.Message - ua *net.UDPAddr -} - -func (c *batchingUDPConn) getSendBatch() *sendBatch { - batch := c.sendBatchPool.Get().(*sendBatch) - return batch -} - -func (c *batchingUDPConn) putSendBatch(batch *sendBatch) { - for i := range batch.msgs { - batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} - } - c.sendBatchPool.Put(batch) -} - -func (c *batchingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { - batch := c.getSendBatch() - defer c.putSendBatch(batch) - if addr.Addr().Is6() { - as16 := addr.Addr().As16() - copy(batch.ua.IP, as16[:]) - batch.ua.IP = batch.ua.IP[:16] - } else { - as4 := addr.Addr().As4() - copy(batch.ua.IP, as4[:]) - batch.ua.IP = batch.ua.IP[:4] - } - batch.ua.Port = int(addr.Port()) - var ( - n int - retried bool - ) -retry: - if c.txOffload.Load() { - n = c.coalesceMessages(batch.ua, buffs, batch.msgs) - } else { - for i := range buffs { - batch.msgs[i].Buffers[0] = buffs[i] - batch.msgs[i].Addr = batch.ua - batch.msgs[i].OOB = batch.msgs[i].OOB[:0] - } - n = len(buffs) - } - - err := c.writeBatch(batch.msgs[:n]) - if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) { - c.txOffload.Store(false) - retried = true - goto retry - } - if retried { - return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err} - } - return err -} - func (c *batchingUDPConn) writeBatch(msgs []ipv6.Message) error { var head int for { @@ -3772,205 +2575,12 @@ func tryUpgradeToBatchingUDPConn(pconn nettype.PacketConn, network string, batch return b } -// RebindingUDPConn is a UDP socket that can be re-bound. -// Unix has no notion of re-binding a socket, so we swap it out for a new one. 
-type RebindingUDPConn struct { - // pconnAtomic is a pointer to the value stored in pconn, but doesn't - // require acquiring mu. It's used for reads/writes and only upon failure - // do the reads/writes then check pconn (after acquiring mu) to see if - // there's been a rebind meanwhile. - // pconn isn't really needed, but makes some of the code simpler - // to keep it distinct. - // Neither is expected to be nil, sockets are bound on creation. - pconnAtomic atomic.Pointer[nettype.PacketConn] - - mu sync.Mutex // held while changing pconn (and pconnAtomic) - pconn nettype.PacketConn - port uint16 -} - -// setConnLocked sets the provided nettype.PacketConn. It should be called only -// after acquiring RebindingUDPConn.mu. It upgrades the provided -// nettype.PacketConn to a *batchingUDPConn when appropriate. This upgrade -// is intentionally pushed closest to where read/write ops occur in order to -// avoid disrupting surrounding code that assumes nettype.PacketConn is a -// *net.UDPConn. -func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) { - upc := tryUpgradeToBatchingUDPConn(p, network, batchSize) - c.pconn = upc - c.pconnAtomic.Store(&upc) - c.port = uint16(c.localAddrLocked().Port) -} - -// currentConn returns c's current pconn, acquiring c.mu in the process. -func (c *RebindingUDPConn) currentConn() nettype.PacketConn { - c.mu.Lock() - defer c.mu.Unlock() - return c.pconn -} - -func (c *RebindingUDPConn) readFromWithInitPconn(pconn nettype.PacketConn, b []byte) (int, netip.AddrPort, error) { - for { - n, addr, err := pconn.ReadFromUDPAddrPort(b) - if err != nil && pconn != c.currentConn() { - pconn = *c.pconnAtomic.Load() - continue - } - return n, addr, err - } -} - -// ReadFromUDPAddrPort reads a packet from c into b. -// It returns the number of bytes copied and the source address. 
-func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) { - return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b) -} - -// WriteBatchTo writes buffs to addr. -func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { - for { - pconn := *c.pconnAtomic.Load() - b, ok := pconn.(*batchingUDPConn) - if !ok { - for _, buf := range buffs { - _, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr) - if err != nil { - return err - } - } - return nil - } - err := b.WriteBatchTo(buffs, addr) - if err != nil { - if pconn != c.currentConn() { - continue - } - return err - } - return err - } -} - -// ReadBatch reads messages from c into msgs. It returns the number of messages -// the caller should evaluate for nonzero len, as a zero len message may fall -// on either side of a nonzero. -func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) { - for { - pconn := *c.pconnAtomic.Load() - b, ok := pconn.(*batchingUDPConn) - if !ok { - n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0]) - if err == nil { - msgs[0].N = n - msgs[0].Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(ap)) - return 1, nil - } - return 0, err - } - n, err := b.ReadBatch(msgs, flags) - if err != nil && pconn != c.currentConn() { - continue - } - return n, err - } -} - -func (c *RebindingUDPConn) Port() uint16 { - c.mu.Lock() - defer c.mu.Unlock() - return c.port -} - -func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr { - c.mu.Lock() - defer c.mu.Unlock() - return c.localAddrLocked() -} - -func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr { - return c.pconn.LocalAddr().(*net.UDPAddr) -} - -// errNilPConn is returned by RebindingUDPConn.Close when there is no current pconn. -// It is for internal use only and should not be returned to users. 
-var errNilPConn = errors.New("nil pconn") - -func (c *RebindingUDPConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - return c.closeLocked() -} - -func (c *RebindingUDPConn) closeLocked() error { - if c.pconn == nil { - return errNilPConn - } - c.port = 0 - return c.pconn.Close() -} - -func (c *RebindingUDPConn) writeToUDPAddrPortWithInitPconn(pconn nettype.PacketConn, b []byte, addr netip.AddrPort) (int, error) { - for { - n, err := pconn.WriteToUDPAddrPort(b, addr) - if err != nil && pconn != c.currentConn() { - pconn = *c.pconnAtomic.Load() - continue - } - return n, err - } -} - -func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { - return c.writeToUDPAddrPortWithInitPconn(*c.pconnAtomic.Load(), b, addr) -} - func newBlockForeverConn() *blockForeverConn { c := new(blockForeverConn) c.cond = sync.NewCond(&c.mu) return c } -// blockForeverConn is a net.PacketConn whose reads block until it is closed. -type blockForeverConn struct { - mu sync.Mutex - cond *sync.Cond - closed bool -} - -func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - c.mu.Lock() - for !c.closed { - c.cond.Wait() - } - c.mu.Unlock() - return 0, netip.AddrPort{}, net.ErrClosed -} - -func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { - // Silently drop writes. - return len(p), nil -} - -func (c *blockForeverConn) LocalAddr() net.Addr { - // Return a *net.UDPAddr because lots of code assumes that it will. 
- return new(net.UDPAddr) -} - -func (c *blockForeverConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return net.ErrClosed - } - c.closed = true - c.cond.Broadcast() - return nil -} - -func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } - // simpleDur rounds d such that it stringifies to something short. func simpleDur(d time.Duration) time.Duration { if d < time.Second { @@ -3982,88 +2592,61 @@ func simpleDur(d time.Duration) time.Duration { return d.Round(time.Minute) } -func sbPrintAddr(sb *strings.Builder, a netip.AddrPort) { - is6 := a.Addr().Is6() - if is6 { - sb.WriteByte('[') - } - fmt.Fprintf(sb, "%s", a.Addr()) - if is6 { - sb.WriteByte(']') - } - fmt.Fprintf(sb, ":%d", a.Port()) -} - -func (c *Conn) derpRegionCodeOfAddrLocked(ipPort string) string { - _, portStr, err := net.SplitHostPort(ipPort) - if err != nil { - return "" - } - regionID, err := strconv.Atoi(portStr) - if err != nil { - return "" - } - return c.derpRegionCodeOfIDLocked(regionID) -} +// UpdateNetmapDelta implements controlclient.NetmapDeltaUpdater. 
+func (c *Conn) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { + c.mu.Lock() + defer c.mu.Unlock() -func (c *Conn) derpRegionCodeOfIDLocked(regionID int) string { - if c.derpMap == nil { - return "" - } - if r, ok := c.derpMap.Regions[regionID]; ok { - return r.RegionCode + for _, m := range muts { + nodeID := m.NodeIDBeingMutated() + ep, ok := c.peerMap.endpointForNodeID(nodeID) + if !ok { + continue + } + switch m := m.(type) { + case netmap.NodeMutationDERPHome: + ep.setDERPHome(uint16(m.DERPRegion)) + case netmap.NodeMutationEndpoints: + ep.mu.Lock() + ep.setEndpointsLocked(views.SliceOf(m.Endpoints)) + ep.mu.Unlock() + } } - return "" + return true } +// UpdateStatus implements the interface nede by ipnstate.StatusBuilder. +// +// This method adds in the magicsock-specific information only. Most +// of the status is otherwise populated by LocalBackend. func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) { c.mu.Lock() defer c.mu.Unlock() - var tailscaleIPs []netip.Addr - if c.netMap != nil { - tailscaleIPs = make([]netip.Addr, 0, len(c.netMap.Addresses)) - for _, addr := range c.netMap.Addresses { - if !addr.IsSingleIP() { - continue - } - sb.AddTailscaleIP(addr.Addr()) - tailscaleIPs = append(tailscaleIPs, addr.Addr()) - } - } - sb.MutateSelfStatus(func(ss *ipnstate.PeerStatus) { - if !c.privateKey.IsZero() { - ss.PublicKey = c.privateKey.Public() - } else { - ss.PublicKey = key.NodePublic{} - } ss.Addrs = make([]string, 0, len(c.lastEndpoints)) for _, ep := range c.lastEndpoints { ss.Addrs = append(ss.Addrs, ep.Addr.String()) } - ss.OS = version.OS() if c.derpMap != nil { - derpRegion, ok := c.derpMap.Regions[c.myDerp] - if ok { - ss.Relay = derpRegion.RegionCode + if reg, ok := c.derpMap.Regions[c.myDerp]; ok { + ss.Relay = reg.RegionCode } } - ss.TailscaleIPs = tailscaleIPs }) if sb.WantPeers { c.peerMap.forEachEndpoint(func(ep *endpoint) { ps := &ipnstate.PeerStatus{InMagicSock: true} - //ps.Addrs = append(ps.Addrs, n.Endpoints...) 
ep.populatePeerStatus(ps) sb.AddPeer(ep.publicKey, ps) }) } c.foreachActiveDerpSortedLocked(func(node int, ad activeDerp) { - // TODO(bradfitz): add to ipnstate.StatusBuilder - //f("

  • derp-%v: cr%v,wr%v
  • ", node, simpleDur(now.Sub(ad.createTime)), simpleDur(now.Sub(*ad.lastWrite))) + // TODO(bradfitz): add a method to ipnstate.StatusBuilder + // to include all the DERP connections we have open + // and add it here. See the other caller of foreachActiveDerpSortedLocked. }) } @@ -4073,89 +2656,6 @@ func (c *Conn) SetStatistics(stats *connstats.Statistics) { c.stats.Store(stats) } -func ippDebugString(ua netip.AddrPort) string { - if ua.Addr() == derpMagicIPAddr { - return fmt.Sprintf("derp-%d", ua.Port()) - } - return ua.String() -} - -// endpointSendFunc is a func that writes encrypted Wireguard payloads from -// WireGuard to a peer. It might write via UDP, DERP, both, or neither. -// -// What these funcs should NOT do is too much work. Minimize use of mutexes, map -// lookups, etc. The idea is that selecting the path to use is done infrequently -// and mostly async from sending packets. When conditions change (including the -// passing of time and loss of confidence in certain routes), then a new send -// func gets set on an sendpoint. -// -// A nil value means the current fast path has expired and needs to be -// recalculated. -type endpointSendFunc func([][]byte) error - -// endpointDisco is the current disco key and short string for an endpoint. This -// structure is immutable. -type endpointDisco struct { - key key.DiscoPublic // for discovery messages. - short string // ShortString of discoKey. -} - -// endpoint is a wireguard/conn.Endpoint. In wireguard-go and kernel WireGuard -// there is only one endpoint for a peer, but in Tailscale we distribute a -// number of possible endpoints for a peer which would include the all the -// likely addresses at which a peer may be reachable. This endpoint type holds -// the information required that when WiregGuard-Go wants to send to a -// particular peer (essentally represented by this endpoint type), the send -// function can use the currnetly best known Tailscale endpoint to send packets -// to the peer. 
-type endpoint struct { - // atomically accessed; declared first for alignment reasons - lastRecv mono.Time - numStopAndResetAtomic int64 - sendFunc syncs.AtomicValue[endpointSendFunc] // nil or unset means unused - debugUpdates *ringbuffer.RingBuffer[EndpointChange] - - // These fields are initialized once and never modified. - c *Conn - publicKey key.NodePublic // peer public key (for WireGuard + DERP) - publicKeyHex string // cached output of publicKey.UntypedHexString - fakeWGAddr netip.AddrPort // the UDP address we tell wireguard-go we're using - nodeAddr netip.Addr // the node's first tailscale address; used for logging & wireguard rate-limiting (Issue 6686) - - disco atomic.Pointer[endpointDisco] // if the peer supports disco, the key and short string - - // mu protects all following fields. - mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu - - heartBeatTimer *time.Timer // nil when idle - lastSend mono.Time // last time there was outgoing packets sent to this peer (from wireguard-go) - lastFullPing mono.Time // last time we pinged all disco endpoints - derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) - - bestAddr addrLatency // best non-DERP path; zero if none - bestAddrAt mono.Time // time best address re-confirmed - trustBestAddrUntil mono.Time // time when bestAddr expires - sentPing map[stun.TxID]sentPing - endpointState map[netip.AddrPort]*endpointState - isCallMeMaybeEP map[netip.AddrPort]bool - - pendingCLIPings []pendingCLIPing // any outstanding "tailscale ping" commands running - - // The following fields are related to the new "silent disco" - // implementation that's a WIP as of 2022-10-20. - // See #540 for background. 
- heartbeatDisabled bool - pathFinderRunning bool - - expired bool // whether the node has expired - isWireguardOnly bool // whether the endpoint is WireGuard only -} - -type pendingCLIPing struct { - res *ipnstate.PingResult - cb func(*ipnstate.PingResult) -} - const ( // sessionActiveTimeout is how long since the last activity we // try to keep an established endpoint peering alive. @@ -4179,23 +2679,10 @@ const ( // try to upgrade to a better path. goodEnoughLatency = 5 * time.Millisecond - // derpInactiveCleanupTime is how long a non-home DERP connection - // needs to be idle (last written to) before we close it. - derpInactiveCleanupTime = 60 * time.Second - - // derpCleanStaleInterval is how often cleanStaleDerp runs when there - // are potentially-stale DERP connections to close. - derpCleanStaleInterval = 15 * time.Second - // endpointsFreshEnoughDuration is how long we consider a // STUN-derived endpoint valid for. UDP NAT mappings typically // expire at 30 seconds, so this is a few seconds shy of that. endpointsFreshEnoughDuration = 27 * time.Second - - // endpointTrackerLifetime is how long we continue advertising an - // endpoint after we last see it. This is intentionally chosen to be - // slightly longer than a full netcheck period. - endpointTrackerLifetime = 5*time.Minute + 10*time.Second ) // Constants that are variable for testing. @@ -4209,850 +2696,23 @@ var ( // resetting the counter, as the first pings likely didn't through // the firewall) discoPingInterval = 5 * time.Second -) - -// endpointState is some state and history for a specific endpoint of -// a endpoint. (The subject is the endpoint.endpointState -// map key) -type endpointState struct { - // all fields guarded by endpoint.mu - - // lastPing is the last (outgoing) ping time. - lastPing mono.Time - - // lastGotPing, if non-zero, means that this was an endpoint - // that we learned about at runtime (from an incoming ping) - // and that is not in the network map. 
If so, we keep the time - // updated and use it to discard old candidates. - lastGotPing time.Time - - // lastGotPingTxID contains the TxID for the last incoming ping. This is - // used to de-dup incoming pings that we may see on both the raw disco - // socket on Linux, and UDP socket. We cannot rely solely on the raw socket - // disco handling due to https://github.com/tailscale/tailscale/issues/7078. - lastGotPingTxID stun.TxID - - // callMeMaybeTime, if non-zero, is the time this endpoint - // was advertised last via a call-me-maybe disco message. - callMeMaybeTime time.Time - - recentPongs []pongReply // ring buffer up to pongHistoryCount entries - recentPong uint16 // index into recentPongs of most recent; older before, wrapped - index int16 // index in nodecfg.Node.Endpoints; meaningless if lastGotPing non-zero -} + // wireguardPingInterval is the minimum time between pings to an endpoint. + // Pings are only sent if we have not observed bidirectional traffic with an + // endpoint in at least this duration. + wireguardPingInterval = 5 * time.Second +) // indexSentinelDeleted is the temporary value that endpointState.index takes while // a endpoint's endpoints are being updated from a new network map. const indexSentinelDeleted = -1 -// shouldDeleteLocked reports whether we should delete this endpoint. -func (st *endpointState) shouldDeleteLocked() bool { - switch { - case !st.callMeMaybeTime.IsZero(): - return false - case st.lastGotPing.IsZero(): - // This was an endpoint from the network map. Is it still in the network map? - return st.index == indexSentinelDeleted - default: - // This was an endpoint discovered at runtime. - return time.Since(st.lastGotPing) > sessionActiveTimeout - } -} - -// latencyLocked returns the most recent latency measurement, if any. -// endpoint.mu must be held. 
-func (st *endpointState) latencyLocked() (lat time.Duration, ok bool) { - if len(st.recentPongs) == 0 { - return 0, false - } - return st.recentPongs[st.recentPong].latency, true -} - -func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) { - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "deleteEndpointLocked-" + why, - From: ep, - }) - delete(de.endpointState, ep) - if de.bestAddr.AddrPort == ep { - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "deleteEndpointLocked-bestAddr-" + why, - From: de.bestAddr, - }) - de.bestAddr = addrLatency{} - } -} - -// pongHistoryCount is how many pongReply values we keep per endpointState -const pongHistoryCount = 64 - -type pongReply struct { - latency time.Duration - pongAt mono.Time // when we received the pong - from netip.AddrPort // the pong's src (usually same as endpoint map key) - pongSrc netip.AddrPort // what they reported they heard -} - -type sentPing struct { - to netip.AddrPort - at mono.Time - timer *time.Timer // timeout timer - purpose discoPingPurpose -} - -// initFakeUDPAddr populates fakeWGAddr with a globally unique fake UDPAddr. -// The current implementation just uses the pointer value of de jammed into an IPv6 -// address, but it could also be, say, a counter. -func (de *endpoint) initFakeUDPAddr() { - var addr [16]byte - addr[0] = 0xfd - addr[1] = 0x00 - binary.BigEndian.PutUint64(addr[2:], uint64(reflect.ValueOf(de).Pointer())) - de.fakeWGAddr = netip.AddrPortFrom(netip.AddrFrom16(addr).Unmap(), 12345) -} - -// noteRecvActivity records receive activity on de, and invokes -// Conn.noteRecvActivity no more than once every 10s. 
-func (de *endpoint) noteRecvActivity() { - if de.c.noteRecvActivity == nil { - return - } - now := mono.Now() - elapsed := now.Sub(de.lastRecv.LoadAtomic()) - if elapsed > 10*time.Second { - de.lastRecv.StoreAtomic(now) - de.c.noteRecvActivity(de.publicKey) - } -} - -func (de *endpoint) discoShort() string { - var short string - if d := de.disco.Load(); d != nil { - short = d.short - } - return short -} - -// String exists purely so wireguard-go internals can log.Printf("%v") -// its internal conn.Endpoints and we don't end up with data races -// from fmt (via log) reading mutex fields and such. -func (de *endpoint) String() string { - return fmt.Sprintf("magicsock.endpoint{%v, %v}", de.publicKey.ShortString(), de.discoShort()) -} - -func (de *endpoint) ClearSrc() {} -func (de *endpoint) SrcToString() string { panic("unused") } // unused by wireguard-go -func (de *endpoint) SrcIP() netip.Addr { panic("unused") } // unused by wireguard-go -func (de *endpoint) DstToString() string { return de.publicKeyHex } -func (de *endpoint) DstIP() netip.Addr { return de.nodeAddr } // see tailscale/tailscale#6686 -func (de *endpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) } - -// addrForSendLocked returns the address(es) that should be used for -// sending the next packet. Zero, one, or both of UDP address and DERP -// addr may be non-zero. If the endpoint is WireGuard only and does not have -// latency information, a bool is returned to indiciate that the -// WireGuard latency discovery pings should be sent. -// -// de.mu must be held. -func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.AddrPort, sendWGPing bool) { - udpAddr = de.bestAddr.AddrPort - - if udpAddr.IsValid() && !now.After(de.trustBestAddrUntil) { - return udpAddr, netip.AddrPort{}, false - } - - if de.isWireguardOnly { - // If the endpoint is wireguard-only, we don't have a DERP - // address to send to, so we have to send to the UDP address. 
- udpAddr, shouldPing := de.addrForWireGuardSendLocked(now) - return udpAddr, netip.AddrPort{}, shouldPing - } - - // We had a bestAddr but it expired so send both to it - // and DERP. - return udpAddr, de.derpAddr, false -} - -// addrForWireGuardSendLocked returns the address that should be used for -// sending the next packet. If a packet has never or not recently been sent to -// the endpoint, then a randomly selected address for the endpoint is returned, -// as well as a bool indiciating that WireGuard discovery pings should be started. -// If the addresses have latency information available, then the address with the -// best latency is used. -// -// de.mu must be held. -func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) { - // lowestLatency is a high duration initially, so we - // can be sure we're going to have a duration lower than this - // for the first latency retrieved. - lowestLatency := time.Hour - for ipp, state := range de.endpointState { - if latency, ok := state.latencyLocked(); ok { - if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() { - // If we have the same latency,IPv6 is prioritized. - // TODO(catzkorn): Consider a small increase in latency to use - // IPv6 in comparison to IPv4, when possible. - lowestLatency = latency - udpAddr = ipp - } - } - } - - if udpAddr.IsValid() { - // Set trustBestAddrUntil to an hour, so we will - // continue to use this address for a long period of time. 
- de.bestAddr.AddrPort = udpAddr - de.trustBestAddrUntil = now.Add(1 * time.Hour) - return udpAddr, false - } - - candidates := make([]netip.AddrPort, 0, len(de.endpointState)) - for ipp := range de.endpointState { - if ipp.Addr().Is4() && de.c.noV4.Load() { - continue - } - if ipp.Addr().Is6() && de.c.noV6.Load() { - continue - } - candidates = append(candidates, ipp) - } - // Randomly select an address to use until we retrieve latency information - // and give it a short trustBestAddrUntil time so we avoid flapping between - // addresses while waiting on latency information to be populated. - udpAddr = candidates[rand.Intn(len(candidates))] - de.bestAddr.AddrPort = udpAddr - if len(candidates) == 1 { - // if we only have one address that we can send data too, - // we should trust it for a longer period of time. - de.trustBestAddrUntil = now.Add(1 * time.Hour) - } else { - de.trustBestAddrUntil = now.Add(15 * time.Second) - } - - return udpAddr, len(candidates) > 1 -} - -// heartbeat is called every heartbeatInterval to keep the best UDP path alive, -// or kick off discovery of other paths. -func (de *endpoint) heartbeat() { - de.mu.Lock() - defer de.mu.Unlock() - - de.heartBeatTimer = nil - - if de.heartbeatDisabled { - // If control override to disable heartBeatTimer set, return early. - return - } - - if de.lastSend.IsZero() { - // Shouldn't happen. - return - } - - if mono.Since(de.lastSend) > sessionActiveTimeout { - // Session's idle. Stop heartbeating. - de.c.dlogf("[v1] magicsock: disco: ending heartbeats for idle session to %v (%v)", de.publicKey.ShortString(), de.discoShort()) - return - } - - now := mono.Now() - udpAddr, _, _ := de.addrForSendLocked(now) - if udpAddr.IsValid() { - // We have a preferred path. Ping that every 2 seconds. 
- de.startDiscoPingLocked(udpAddr, now, pingHeartbeat) - } - - if de.wantFullPingLocked(now) { - de.sendDiscoPingsLocked(now, true) - } - - de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) -} - -// wantFullPingLocked reports whether we should ping to all our peers looking for -// a better path. -// -// de.mu must be held. -func (de *endpoint) wantFullPingLocked(now mono.Time) bool { - if runtime.GOOS == "js" { - return false - } - if !de.bestAddr.IsValid() || de.lastFullPing.IsZero() { - return true - } - if now.After(de.trustBestAddrUntil) { - return true - } - if de.bestAddr.latency <= goodEnoughLatency { - return false - } - if now.Sub(de.lastFullPing) >= upgradeInterval { - return true - } - return false -} - -func (de *endpoint) noteActiveLocked() { - de.lastSend = mono.Now() - if de.heartBeatTimer == nil && !de.heartbeatDisabled { - de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) - } -} - -// cliPing starts a ping for the "tailscale ping" command. res is value to call cb with, -// already partially filled. -func (de *endpoint) cliPing(res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { - de.mu.Lock() - defer de.mu.Unlock() - - if de.expired { - res.Err = errExpired.Error() - cb(res) - return - } - - de.pendingCLIPings = append(de.pendingCLIPings, pendingCLIPing{res, cb}) - - now := mono.Now() - udpAddr, derpAddr, _ := de.addrForSendLocked(now) - if derpAddr.IsValid() { - de.startDiscoPingLocked(derpAddr, now, pingCLI) - } - if udpAddr.IsValid() && now.Before(de.trustBestAddrUntil) { - // Already have an active session, so just ping the address we're using. - // Otherwise "tailscale ping" results to a node on the local network - // can look like they're bouncing between, say 10.0.0.0/9 and the peer's - // IPv6 address, both 1ms away, and it's random who replies first. 
- de.startDiscoPingLocked(udpAddr, now, pingCLI) - } else { - for ep := range de.endpointState { - de.startDiscoPingLocked(ep, now, pingCLI) - } - } - de.noteActiveLocked() -} - -var ( - errExpired = errors.New("peer's node key has expired") - errNoUDPOrDERP = errors.New("no UDP or DERP addr") -) - -func (de *endpoint) send(buffs [][]byte) error { - if fn := de.sendFunc.Load(); fn != nil { - return fn(buffs) - } - - de.mu.Lock() - if de.expired { - de.mu.Unlock() - return errExpired - } - - // if heartbeat disabled, kick off pathfinder - if de.heartbeatDisabled { - if !de.pathFinderRunning { - de.startPathFinder() - } - } - - now := mono.Now() - udpAddr, derpAddr, startWGPing := de.addrForSendLocked(now) - - if de.isWireguardOnly { - if startWGPing { - de.sendWireGuardOnlyPingsLocked(now) - } - } else if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { - de.sendDiscoPingsLocked(now, true) - } - de.noteActiveLocked() - de.mu.Unlock() - - if !udpAddr.IsValid() && !derpAddr.IsValid() { - return errNoUDPOrDERP - } - var err error - if udpAddr.IsValid() { - _, err = de.c.sendUDPBatch(udpAddr, buffs) - // TODO(raggi): needs updating for accuracy, as in error conditions we may have partial sends. 
- if stats := de.c.stats.Load(); err == nil && stats != nil { - var txBytes int - for _, b := range buffs { - txBytes += len(b) - } - stats.UpdateTxPhysical(de.nodeAddr, udpAddr, txBytes) - } - } - if derpAddr.IsValid() { - allOk := true - for _, buff := range buffs { - ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff) - if stats := de.c.stats.Load(); stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buff)) - } - if !ok { - allOk = false - } - } - if allOk { - return nil - } - } - return err -} - -func (de *endpoint) discoPingTimeout(txid stun.TxID) { - de.mu.Lock() - defer de.mu.Unlock() - sp, ok := de.sentPing[txid] - if !ok { - return - } - if debugDisco() || !de.bestAddr.IsValid() || mono.Now().After(de.trustBestAddrUntil) { - de.c.dlogf("[v1] magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], sp.to, de.publicKey.ShortString(), de.discoShort()) - } - de.removeSentDiscoPingLocked(txid, sp) -} - -// forgetDiscoPing is called by a timer when a ping either fails to send or -// has taken too long to get a pong reply. -func (de *endpoint) forgetDiscoPing(txid stun.TxID) { - de.mu.Lock() - defer de.mu.Unlock() - if sp, ok := de.sentPing[txid]; ok { - de.removeSentDiscoPingLocked(txid, sp) - } -} - -func (de *endpoint) removeSentDiscoPingLocked(txid stun.TxID, sp sentPing) { - // Stop the timer for the case where sendPing failed to write to UDP. - // In the case of a timer already having fired, this is a no-op: - sp.timer.Stop() - delete(de.sentPing, txid) -} - -// sendDiscoPing sends a ping with the provided txid to ep using de's discoKey. -// -// The caller (startPingLocked) should've already recorded the ping in -// sentPing and set up the timer. -// -// The caller should use de.discoKey as the discoKey argument. -// It is passed in so that sendDiscoPing doesn't need to lock de.mu. 
-func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, txid stun.TxID, logLevel discoLogLevel) { - sent, _ := de.c.sendDiscoMessage(ep, de.publicKey, discoKey, &disco.Ping{ - TxID: [12]byte(txid), - NodeKey: de.c.publicKeyAtomic.Load(), - }, logLevel) - if !sent { - de.forgetDiscoPing(txid) - } -} - -// discoPingPurpose is the reason why a discovery ping message was sent. -type discoPingPurpose int - -//go:generate go run tailscale.com/cmd/addlicense -file discopingpurpose_string.go go run golang.org/x/tools/cmd/stringer -type=discoPingPurpose -trimprefix=ping -const ( - // pingDiscovery means that purpose of a ping was to see if a - // path was valid. - pingDiscovery discoPingPurpose = iota - - // pingHeartbeat means that purpose of a ping was whether a - // peer was still there. - pingHeartbeat - - // pingCLI means that the user is running "tailscale ping" - // from the CLI. These types of pings can go over DERP. - pingCLI -) - -func (de *endpoint) startDiscoPingLocked(ep netip.AddrPort, now mono.Time, purpose discoPingPurpose) { - if runtime.GOOS == "js" { - return - } - epDisco := de.disco.Load() - if epDisco == nil { - return - } - if purpose != pingCLI { - st, ok := de.endpointState[ep] - if !ok { - // Shouldn't happen. But don't ping an endpoint that's - // not active for us. 
- de.c.logf("magicsock: disco: [unexpected] attempt to ping no longer live endpoint %v", ep) - return - } - st.lastPing = now - } - - txid := stun.NewTxID() - de.sentPing[txid] = sentPing{ - to: ep, - at: now, - timer: time.AfterFunc(pingTimeoutDuration, func() { de.discoPingTimeout(txid) }), - purpose: purpose, - } - logLevel := discoLog - if purpose == pingHeartbeat { - logLevel = discoVerboseLog - } - go de.sendDiscoPing(ep, epDisco.key, txid, logLevel) -} - -func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) { - de.lastFullPing = now - var sentAny bool - for ep, st := range de.endpointState { - if st.shouldDeleteLocked() { - de.deleteEndpointLocked("sendPingsLocked", ep) - continue - } - if runtime.GOOS == "js" { - continue - } - if !st.lastPing.IsZero() && now.Sub(st.lastPing) < discoPingInterval { - continue - } - - firstPing := !sentAny - sentAny = true - - if firstPing && sendCallMeMaybe { - de.c.dlogf("[v1] magicsock: disco: send, starting discovery for %v (%v)", de.publicKey.ShortString(), de.discoShort()) - } - - de.startDiscoPingLocked(ep, now, pingDiscovery) - } - derpAddr := de.derpAddr - if sentAny && sendCallMeMaybe && derpAddr.IsValid() { - // Have our magicsock.Conn figure out its STUN endpoint (if - // it doesn't know already) and then send a CallMeMaybe - // message to our peer via DERP informing them that we've - // sent so our firewall ports are probably open and now - // would be a good time for them to connect. - go de.c.enqueueCallMeMaybe(derpAddr, de) - } -} - -// sendWireGuardOnlyPingsLocked evaluates all available addresses for -// a WireGuard only endpoint and initates an ICMP ping for useable -// addresses. 
-func (de *endpoint) sendWireGuardOnlyPingsLocked(now mono.Time) { - if runtime.GOOS == "js" { - return - } - - // Normally the we only send pings at a low rate as the decision to start - // sending a ping sets bestAddrAtUntil with a reasonable time to keep trying - // that address, however, if that code changed we may want to be sure that - // we don't ever send excessive pings to avoid impact to the client/user. - if !now.After(de.lastFullPing.Add(10 * time.Second)) { - return - } - de.lastFullPing = now - - for ipp := range de.endpointState { - if ipp.Addr().Is4() && de.c.noV4.Load() { - continue - } - if ipp.Addr().Is6() && de.c.noV6.Load() { - continue - } - - go de.sendWireGuardOnlyPing(ipp, now) - } -} - -// getPinger lazily instantiates a pinger and returns it, if it was -// already instantiated it returns the existing one. -func (c *Conn) getPinger() *ping.Pinger { - return c.wgPinger.Get(func() *ping.Pinger { - return ping.New(c.connCtx, c.dlogf, netns.Listener(c.logf, c.netMon)) - }) -} - -// sendWireGuardOnlyPing sends a ICMP ping to a WireGuard only address to -// discover the latency. -func (de *endpoint) sendWireGuardOnlyPing(ipp netip.AddrPort, now mono.Time) { - ctx, cancel := context.WithTimeout(de.c.connCtx, 5*time.Second) - defer cancel() - - de.setLastPing(ipp, now) - - addr := &net.IPAddr{ - IP: net.IP(ipp.Addr().AsSlice()), - Zone: ipp.Addr().Zone(), - } - - p := de.c.getPinger() - if p == nil { - de.c.logf("[v2] magicsock: sendWireGuardOnlyPingLocked: pinger is nil") - return - } - - latency, err := p.Send(ctx, addr, nil) - if err != nil { - de.c.logf("[v2] magicsock: sendWireGuardOnlyPingLocked: %s", err) - return - } - - de.mu.Lock() - defer de.mu.Unlock() - - state, ok := de.endpointState[ipp] - if !ok { - return - } - state.addPongReplyLocked(pongReply{ - latency: latency, - pongAt: now, - from: ipp, - pongSrc: netip.AddrPort{}, // We don't know this. - }) -} - -// setLastPing sets lastPing on the endpointState to now. 
-func (de *endpoint) setLastPing(ipp netip.AddrPort, now mono.Time) { - de.mu.Lock() - defer de.mu.Unlock() - state, ok := de.endpointState[ipp] - if !ok { - return - } - state.lastPing = now -} - -// updateFromNode updates the endpoint based on a tailcfg.Node from a NetMap -// update. -func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { - if n == nil { - panic("nil node when updating endpoint") - } - de.mu.Lock() - defer de.mu.Unlock() - - de.heartbeatDisabled = heartbeatDisabled - de.expired = n.Expired - - epDisco := de.disco.Load() - var discoKey key.DiscoPublic - if epDisco != nil { - discoKey = epDisco.key - } - - if discoKey != n.DiscoKey { - de.c.logf("[v1] magicsock: disco: node %s changed from %s to %s", de.publicKey.ShortString(), discoKey, n.DiscoKey) - de.disco.Store(&endpointDisco{ - key: n.DiscoKey, - short: n.DiscoKey.ShortString(), - }) - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "updateFromNode-resetLocked", - }) - de.resetLocked() - } - if n.DERP == "" { - if de.derpAddr.IsValid() { - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "updateFromNode-remove-DERP", - From: de.derpAddr, - }) - } - de.derpAddr = netip.AddrPort{} - } else { - newDerp, _ := netip.ParseAddrPort(n.DERP) - if de.derpAddr != newDerp { - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "updateFromNode-DERP", - From: de.derpAddr, - To: newDerp, - }) - } - de.derpAddr = newDerp - } - - for _, st := range de.endpointState { - st.index = indexSentinelDeleted // assume deleted until updated in next loop - } - - var newIpps []netip.AddrPort - for i, epStr := range n.Endpoints { - if i > math.MaxInt16 { - // Seems unlikely. 
- continue - } - ipp, err := netip.ParseAddrPort(epStr) - if err != nil { - de.c.logf("magicsock: bogus netmap endpoint %q", epStr) - continue - } - if st, ok := de.endpointState[ipp]; ok { - st.index = int16(i) - } else { - de.endpointState[ipp] = &endpointState{index: int16(i)} - newIpps = append(newIpps, ipp) - } - } - if len(newIpps) > 0 { - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "updateFromNode-new-Endpoints", - To: newIpps, - }) - } - - // Now delete anything unless it's still in the network map or - // was a recently discovered endpoint. - for ep, st := range de.endpointState { - if st.shouldDeleteLocked() { - de.deleteEndpointLocked("updateFromNode", ep) - } - } - - // Node changed. Invalidate its sending fast path, if any. - de.sendFunc.Store(nil) -} - -// addCandidateEndpoint adds ep as an endpoint to which we should send -// future pings. If there is an existing endpointState for ep, and forRxPingTxID -// matches the last received ping TxID, this function reports true, otherwise -// false. -// -// This is called once we've already verified that we got a valid -// discovery message from de via ep. -func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.TxID) (duplicatePing bool) { - de.mu.Lock() - defer de.mu.Unlock() - - if st, ok := de.endpointState[ep]; ok { - duplicatePing = forRxPingTxID == st.lastGotPingTxID - if !duplicatePing { - st.lastGotPingTxID = forRxPingTxID - } - if st.lastGotPing.IsZero() { - // Already-known endpoint from the network map. - return duplicatePing - } - st.lastGotPing = time.Now() - return duplicatePing - } - - // Newly discovered endpoint. Exciting! - de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort(), de.publicKey.ShortString()) - de.endpointState[ep] = &endpointState{ - lastGotPing: time.Now(), - lastGotPingTxID: forRxPingTxID, - } - - // If for some reason this gets very large, do some cleanup. 
- if size := len(de.endpointState); size > 100 { - for ep, st := range de.endpointState { - if st.shouldDeleteLocked() { - de.deleteEndpointLocked("addCandidateEndpoint", ep) - } - } - size2 := len(de.endpointState) - de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2) - } - return false -} - -// noteConnectivityChange is called when connectivity changes enough -// that we should question our earlier assumptions about which paths -// work. -func (de *endpoint) noteConnectivityChange() { - de.mu.Lock() - defer de.mu.Unlock() - - de.trustBestAddrUntil = 0 -} - -// handlePongConnLocked handles a Pong message (a reply to an earlier ping). -// It should be called with the Conn.mu held. -// -// It reports whether m.TxID corresponds to a ping that this endpoint sent. -func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip.AddrPort) (knownTxID bool) { - de.mu.Lock() - defer de.mu.Unlock() - - isDerp := src.Addr() == derpMagicIPAddr - - sp, ok := de.sentPing[m.TxID] - if !ok { - // This is not a pong for a ping we sent. - return false - } - knownTxID = true // for naked returns below - de.removeSentDiscoPingLocked(m.TxID, sp) - - now := mono.Now() - latency := now.Sub(sp.at) - - if !isDerp { - st, ok := de.endpointState[sp.to] - if !ok { - // This is no longer an endpoint we care about. 
- return - } - - de.c.peerMap.setNodeKeyForIPPort(src, de.publicKey) - - st.addPongReplyLocked(pongReply{ - latency: latency, - pongAt: now, - from: src, - pongSrc: m.Src, - }) - } - - if sp.purpose != pingHeartbeat { - de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pong.src=%v%v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), m.Src, logger.ArgWriter(func(bw *bufio.Writer) { - if sp.to != src { - fmt.Fprintf(bw, " ping.to=%v", sp.to) - } - })) - } - - for _, pp := range de.pendingCLIPings { - de.c.populateCLIPingResponseLocked(pp.res, latency, sp.to) - go pp.cb(pp.res) - } - de.pendingCLIPings = nil - - // Promote this pong response to our current best address if it's lower latency. - // TODO(bradfitz): decide how latency vs. preference order affects decision - if !isDerp { - thisPong := addrLatency{sp.to, latency} - if betterAddr(thisPong, de.bestAddr) { - de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort(), sp.to) - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "handlePingLocked-bestAddr-update", - From: de.bestAddr, - To: thisPong, - }) - de.bestAddr = thisPong - } - if de.bestAddr.AddrPort == thisPong.AddrPort { - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "handlePingLocked-bestAddr-latency", - From: de.bestAddr, - To: thisPong, - }) - de.bestAddr.latency = latency - de.bestAddrAt = now - de.trustBestAddrUntil = now.Add(trustUDPAddrDuration) - } - } - return +// getPinger lazily instantiates a pinger and returns it, if it was +// already instantiated it returns the existing one. 
+func (c *Conn) getPinger() *ping.Pinger { + return c.wgPinger.Get(func() *ping.Pinger { + return ping.New(c.connCtx, c.dlogf, netns.Listener(c.logf, c.netMon)) + }) } // portableTrySetSocketBuffer sets SO_SNDBUF and SO_RECVBUF on pconn to socketBufferSize, @@ -5069,229 +2729,6 @@ func portableTrySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { } } -// addrLatency is an IPPort with an associated latency. -type addrLatency struct { - netip.AddrPort - latency time.Duration -} - -func (a addrLatency) String() string { - return a.AddrPort.String() + "@" + a.latency.String() -} - -// betterAddr reports whether a is a better addr to use than b. -func betterAddr(a, b addrLatency) bool { - if a.AddrPort == b.AddrPort { - return false - } - if !b.IsValid() { - return true - } - if !a.IsValid() { - return false - } - - // Each address starts with a set of points (from 0 to 100) that - // represents how much faster they are than the highest-latency - // endpoint. For example, if a has latency 200ms and b has latency - // 190ms, then a starts with 0 points and b starts with 5 points since - // it's 5% faster. - var aPoints, bPoints int - if a.latency > b.latency && a.latency > 0 { - bPoints = int(100 - ((b.latency * 100) / a.latency)) - } else if b.latency > 0 { - aPoints = int(100 - ((a.latency * 100) / b.latency)) - } - - // Prefer private IPs over public IPs as long as the latencies are - // roughly equivalent, since it's less likely that a user will have to - // pay for the bandwidth in a cloud environment. - // - // Additionally, prefer any loopback address strongly over non-loopback - // addresses. - if a.Addr().IsLoopback() { - aPoints += 50 - } else if a.Addr().IsPrivate() { - aPoints += 20 - } - if b.Addr().IsLoopback() { - bPoints += 50 - } else if b.Addr().IsPrivate() { - bPoints += 20 - } - - // Prefer IPv6 for being a bit more robust, as long as - // the latencies are roughly equivalent. 
- if a.Addr().Is6() { - aPoints += 10 - } - if b.Addr().Is6() { - bPoints += 10 - } - - // Don't change anything if the latency improvement is less than 1%; we - // want a bit of "stickiness" (a.k.a. hysteresis) to avoid flapping if - // there's two roughly-equivalent endpoints. - // - // Points are essentially the percentage improvement of latency vs. the - // slower endpoint; absent any boosts from private IPs, IPv6, etc., a - // will be a better address than b by a fraction of 1% or less if - // aPoints <= 1 and bPoints == 0. - if aPoints <= 1 && bPoints == 0 { - return false - } - - return aPoints > bPoints -} - -// endpoint.mu must be held. -func (st *endpointState) addPongReplyLocked(r pongReply) { - if n := len(st.recentPongs); n < pongHistoryCount { - st.recentPong = uint16(n) - st.recentPongs = append(st.recentPongs, r) - return - } - i := st.recentPong + 1 - if i == pongHistoryCount { - i = 0 - } - st.recentPongs[i] = r - st.recentPong = i -} - -// handleCallMeMaybe handles a CallMeMaybe discovery message via -// DERP. The contract for use of this message is that the peer has -// already sent to us via UDP, so their stateful firewall should be -// open. Now we can Ping back and make it through. -func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { - if runtime.GOOS == "js" { - // Nothing to do on js/wasm if we can't send UDP packets anyway. - return - } - de.mu.Lock() - defer de.mu.Unlock() - - now := time.Now() - for ep := range de.isCallMeMaybeEP { - de.isCallMeMaybeEP[ep] = false // mark for deletion - } - var newEPs []netip.AddrPort - for _, ep := range m.MyNumber { - if ep.Addr().Is6() && ep.Addr().IsLinkLocalUnicast() { - // We send these out, but ignore them for now. - // TODO: teach the ping code to ping on all interfaces - // for these. 
- continue - } - mak.Set(&de.isCallMeMaybeEP, ep, true) - if es, ok := de.endpointState[ep]; ok { - es.callMeMaybeTime = now - } else { - de.endpointState[ep] = &endpointState{callMeMaybeTime: now} - newEPs = append(newEPs, ep) - } - } - if len(newEPs) > 0 { - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "handleCallMeMaybe-new-endpoints", - To: newEPs, - }) - - de.c.dlogf("[v1] magicsock: disco: call-me-maybe from %v %v added new endpoints: %v", - de.publicKey.ShortString(), de.discoShort(), - logger.ArgWriter(func(w *bufio.Writer) { - for i, ep := range newEPs { - if i > 0 { - w.WriteString(", ") - } - w.WriteString(ep.String()) - } - })) - } - - // Delete any prior CallMeMaybe endpoints that weren't included - // in this message. - for ep, want := range de.isCallMeMaybeEP { - if !want { - delete(de.isCallMeMaybeEP, ep) - de.deleteEndpointLocked("handleCallMeMaybe", ep) - } - } - - // Zero out all the lastPing times to force sendPingsLocked to send new ones, - // even if it's been less than 5 seconds ago. - for _, st := range de.endpointState { - st.lastPing = 0 - } - de.sendDiscoPingsLocked(mono.Now(), false) -} - -func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { - de.mu.Lock() - defer de.mu.Unlock() - - ps.Relay = de.c.derpRegionCodeOfIDLocked(int(de.derpAddr.Port())) - - if de.lastSend.IsZero() { - return - } - - now := mono.Now() - ps.LastWrite = de.lastSend.WallTime() - ps.Active = now.Sub(de.lastSend) < sessionActiveTimeout - - if udpAddr, derpAddr, _ := de.addrForSendLocked(now); udpAddr.IsValid() && !derpAddr.IsValid() { - ps.CurAddr = udpAddr.String() - } -} - -// stopAndReset stops timers associated with de and resets its state back to zero. 
-// It's called when a discovery endpoint is no longer present in the -// NetworkMap, or when magicsock is transitioning from running to -// stopped state (via SetPrivateKey(zero)) -func (de *endpoint) stopAndReset() { - atomic.AddInt64(&de.numStopAndResetAtomic, 1) - de.mu.Lock() - defer de.mu.Unlock() - - if closing := de.c.closing.Load(); !closing { - de.c.logf("[v1] magicsock: doing cleanup for discovery key %s", de.discoShort()) - } - - de.debugUpdates.Add(EndpointChange{ - When: time.Now(), - What: "stopAndReset-resetLocked", - }) - de.resetLocked() - if de.heartBeatTimer != nil { - de.heartBeatTimer.Stop() - de.heartBeatTimer = nil - } - de.pendingCLIPings = nil -} - -// resetLocked clears all the endpoint's p2p state, reverting it to a -// DERP-only endpoint. It does not stop the endpoint's heartbeat -// timer, if one is running. -func (de *endpoint) resetLocked() { - de.lastSend = 0 - de.lastFullPing = 0 - de.bestAddr = addrLatency{} - de.bestAddrAt = 0 - de.trustBestAddrUntil = 0 - for _, es := range de.endpointState { - es.lastPing = 0 - } - for txid, sp := range de.sentPing { - de.removeSentDiscoPingLocked(txid, sp) - } -} - -func (de *endpoint) numStopAndReset() int64 { - return atomic.LoadInt64(&de.numStopAndResetAtomic) -} - // derpStr replaces DERP IPs in s with "derp-". func derpStr(s string) string { return strings.ReplaceAll(s, "127.3.3.40:", "derp-") } @@ -5335,95 +2772,6 @@ type discoInfo struct { lastPingTime time.Time } -// derpAddrFamSelector is the derphttp.AddressFamilySelector we pass -// to derphttp.Client.SetAddressFamilySelector. -// -// It provides the hint as to whether in an IPv4-vs-IPv6 race that -// IPv4 should be held back a bit to give IPv6 a better-than-50/50 -// chance of winning. We only return true when we believe IPv6 will -// work anyway, so we don't artificially delay the connection speed. 
-type derpAddrFamSelector struct{ c *Conn } - -func (s derpAddrFamSelector) PreferIPv6() bool { - if r := s.c.lastNetCheckReport.Load(); r != nil { - return r.IPv6 - } - return false -} - -type endpointTrackerEntry struct { - endpoint tailcfg.Endpoint - until time.Time -} - -type endpointTracker struct { - mu sync.Mutex - cache map[netip.AddrPort]endpointTrackerEntry -} - -func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { - epsPlusCached = eps - - var inputEps set.Slice[netip.AddrPort] - for _, ep := range eps { - inputEps.Add(ep.Addr) - } - - et.mu.Lock() - defer et.mu.Unlock() - - // Add entries to the return array that aren't already there. - for k, ep := range et.cache { - // If the endpoint was in the input list, or has expired, skip it. - if inputEps.Contains(k) { - continue - } else if now.After(ep.until) { - continue - } - - // We haven't seen this endpoint; add to the return array - epsPlusCached = append(epsPlusCached, ep.endpoint) - } - - // Add entries from the original input array into the cache, and/or - // extend the lifetime of entries that are already in the cache. - until := now.Add(endpointTrackerLifetime) - for _, ep := range eps { - et.addLocked(now, ep, until) - } - - // Remove everything that has now expired. - et.removeExpiredLocked(now) - return epsPlusCached -} - -// add will store the provided endpoint(s) in the cache for a fixed period of -// time, and remove any entries in the cache that have expired. -// -// et.mu must be held. -func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { - // If we already have an entry for this endpoint, update the timeout on - // it; otherwise, add it. 
- entry, found := et.cache[ep.Addr] - if found { - entry.until = until - } else { - entry = endpointTrackerEntry{ep, until} - } - mak.Set(&et.cache, ep.Addr, entry) -} - -// removeExpired will remove all expired entries from the cache -// -// et.mu must be held -func (et *endpointTracker) removeExpiredLocked(now time.Time) { - for k, ep := range et.cache { - if now.After(ep.until) { - delete(et.cache, k) - } - } -} - var ( metricNumPeers = clientmetric.NewGauge("magicsock_netmap_num_peers") metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns") diff --git a/vendor/tailscale.com/wgengine/magicsock/magicsock_default.go b/vendor/tailscale.com/wgengine/magicsock/magicsock_default.go index 4dda3c8a65..87075e5226 100644 --- a/vendor/tailscale.com/wgengine/magicsock/magicsock_default.go +++ b/vendor/tailscale.com/wgengine/magicsock/magicsock_default.go @@ -31,10 +31,6 @@ func getGSOSizeFromControl(control []byte) (int, error) { func setGSOSizeInControl(control *[]byte, gso uint16) {} -func errShouldDisableOffload(err error) bool { - return false -} - const ( controlMessageSize = 0 ) diff --git a/vendor/tailscale.com/wgengine/magicsock/magicsock_linux.go b/vendor/tailscale.com/wgengine/magicsock/magicsock_linux.go index cdfbeb7590..a4101ccbaa 100644 --- a/vendor/tailscale.com/wgengine/magicsock/magicsock_linux.go +++ b/vendor/tailscale.com/wgengine/magicsock/magicsock_linux.go @@ -303,11 +303,11 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { rc.Control(func(fd uintptr) { errRcv = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, socketBufferSize) if errRcv != nil { - logf("magicsock: failed to force-set UDP read buffer size to %d: %v", socketBufferSize, errRcv) + logf("magicsock: [warning] failed to force-set UDP read buffer size to %d: %v; using kernel default values (impacts throughput only)", socketBufferSize, errRcv) } errSnd = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, 
socketBufferSize) if errSnd != nil { - logf("magicsock: failed to force-set UDP write buffer size to %d: %v", socketBufferSize, errSnd) + logf("magicsock: [warning] failed to force-set UDP write buffer size to %d: %v; using kernel default values (impacts throughput only)", socketBufferSize, errSnd) } }) } diff --git a/vendor/tailscale.com/wgengine/magicsock/pathfinder.go b/vendor/tailscale.com/wgengine/magicsock/pathfinder.go deleted file mode 100644 index 830709d24d..0000000000 --- a/vendor/tailscale.com/wgengine/magicsock/pathfinder.go +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -// startPathFinder initializes the atomicSendFunc, and -// will eventually kick off a goroutine that monitors whether -// that sendFunc is still the best option for the endpoint -// to use and adjusts accordingly. -func (de *endpoint) startPathFinder() { - de.pathFinderRunning = true -} diff --git a/vendor/tailscale.com/wgengine/magicsock/peermap.go b/vendor/tailscale.com/wgengine/magicsock/peermap.go new file mode 100644 index 0000000000..cacba57281 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/peermap.go @@ -0,0 +1,208 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/util/set" +) + +// peerInfo is all the information magicsock tracks about a particular +// peer. +type peerInfo struct { + ep *endpoint // always non-nil. + // ipPorts is an inverted version of peerMap.byIPPort (below), so + // that when we're deleting this node, we can rapidly find out the + // keys that need deleting from peerMap.byIPPort without having to + // iterate over every IPPort known for any peer. 
+ ipPorts set.Set[netip.AddrPort] +} + +func newPeerInfo(ep *endpoint) *peerInfo { + return &peerInfo{ + ep: ep, + ipPorts: set.Set[netip.AddrPort]{}, + } +} + +// peerMap is an index of peerInfos by node (WireGuard) key, disco +// key, and discovered ip:port endpoints. +// +// It doesn't do any locking; all access must be done with Conn.mu held. +type peerMap struct { + byNodeKey map[key.NodePublic]*peerInfo + byIPPort map[netip.AddrPort]*peerInfo + byNodeID map[tailcfg.NodeID]*peerInfo + + // nodesOfDisco contains the set of nodes that are using a + // DiscoKey. Usually those sets will be just one node. + nodesOfDisco map[key.DiscoPublic]set.Set[key.NodePublic] +} + +func newPeerMap() peerMap { + return peerMap{ + byNodeKey: map[key.NodePublic]*peerInfo{}, + byIPPort: map[netip.AddrPort]*peerInfo{}, + byNodeID: map[tailcfg.NodeID]*peerInfo{}, + nodesOfDisco: map[key.DiscoPublic]set.Set[key.NodePublic]{}, + } +} + +// nodeCount returns the number of nodes currently in m. +func (m *peerMap) nodeCount() int { + if len(m.byNodeKey) != len(m.byNodeID) { + devPanicf("internal error: peerMap.byNodeKey and byNodeID out of sync") + } + return len(m.byNodeKey) +} + +// anyEndpointForDiscoKey reports whether there exists any +// peers in the netmap with dk as their DiscoKey. +func (m *peerMap) anyEndpointForDiscoKey(dk key.DiscoPublic) bool { + return len(m.nodesOfDisco[dk]) > 0 +} + +// endpointForNodeKey returns the endpoint for nk, or nil if +// nk is not known to us. +func (m *peerMap) endpointForNodeKey(nk key.NodePublic) (ep *endpoint, ok bool) { + if nk.IsZero() { + return nil, false + } + if info, ok := m.byNodeKey[nk]; ok { + return info.ep, true + } + return nil, false +} + +// endpointForNodeID returns the endpoint for nodeID, or nil if +// nodeID is not known to us. 
+func (m *peerMap) endpointForNodeID(nodeID tailcfg.NodeID) (ep *endpoint, ok bool) { + if info, ok := m.byNodeID[nodeID]; ok { + return info.ep, true + } + return nil, false +} + +// endpointForIPPort returns the endpoint for the peer we +// believe to be at ipp, or nil if we don't know of any such peer. +func (m *peerMap) endpointForIPPort(ipp netip.AddrPort) (ep *endpoint, ok bool) { + if info, ok := m.byIPPort[ipp]; ok { + return info.ep, true + } + return nil, false +} + +// forEachEndpoint invokes f on every endpoint in m. +func (m *peerMap) forEachEndpoint(f func(ep *endpoint)) { + for _, pi := range m.byNodeKey { + f(pi.ep) + } +} + +// forEachEndpointWithDiscoKey invokes f on every endpoint in m that has the +// provided DiscoKey until f returns false or there are no endpoints left to +// iterate. +func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(*endpoint) (keepGoing bool)) { + for nk := range m.nodesOfDisco[dk] { + pi, ok := m.byNodeKey[nk] + if !ok { + // Unexpected. Data structures would have to + // be out of sync. But we don't have a logger + // here to log [unexpected], so just skip. + // Maybe log later once peerMap is merged back + // into Conn. + continue + } + if !f(pi.ep) { + return + } + } +} + +// upsertEndpoint stores endpoint in the peerInfo for +// ep.publicKey, and updates indexes. m must already have a +// tailcfg.Node for ep.publicKey. +func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) { + if ep.nodeID == 0 { + panic("internal error: upsertEndpoint called with zero NodeID") + } + pi, ok := m.byNodeKey[ep.publicKey] + if !ok { + pi = newPeerInfo(ep) + m.byNodeKey[ep.publicKey] = pi + } + m.byNodeID[ep.nodeID] = pi + + epDisco := ep.disco.Load() + if epDisco == nil || oldDiscoKey != epDisco.key { + delete(m.nodesOfDisco[oldDiscoKey], ep.publicKey) + } + if ep.isWireguardOnly { + // If the peer is a WireGuard only peer, add all of its endpoints. 
+ + // TODO(raggi,catzkorn): this could mean that if a "isWireguardOnly" + // peer has, say, 192.168.0.2 and so does a tailscale peer, the + // wireguard one will win. That may not be the outcome that we want - + // perhaps we should prefer bestAddr.AddrPort if it is set? + // see tailscale/tailscale#7994 + for ipp := range ep.endpointState { + m.setNodeKeyForIPPort(ipp, ep.publicKey) + } + return + } + discoSet := m.nodesOfDisco[epDisco.key] + if discoSet == nil { + discoSet = set.Set[key.NodePublic]{} + m.nodesOfDisco[epDisco.key] = discoSet + } + discoSet.Add(ep.publicKey) +} + +// setNodeKeyForIPPort makes future peer lookups by ipp return the +// same endpoint as a lookup by nk. +// +// This should only be called with a fully verified mapping of ipp to +// nk, because calling this function defines the endpoint we hand to +// WireGuard for packets received from ipp. +func (m *peerMap) setNodeKeyForIPPort(ipp netip.AddrPort, nk key.NodePublic) { + if pi := m.byIPPort[ipp]; pi != nil { + delete(pi.ipPorts, ipp) + delete(m.byIPPort, ipp) + } + if pi, ok := m.byNodeKey[nk]; ok { + pi.ipPorts.Add(ipp) + m.byIPPort[ipp] = pi + } +} + +// deleteEndpoint deletes the peerInfo associated with ep, and +// updates indexes. +func (m *peerMap) deleteEndpoint(ep *endpoint) { + if ep == nil { + return + } + ep.stopAndReset() + + epDisco := ep.disco.Load() + + pi := m.byNodeKey[ep.publicKey] + if epDisco != nil { + delete(m.nodesOfDisco[epDisco.key], ep.publicKey) + } + delete(m.byNodeKey, ep.publicKey) + if was, ok := m.byNodeID[ep.nodeID]; ok && was.ep == ep { + delete(m.byNodeID, ep.nodeID) + } + if pi == nil { + // Kneejerk paranoia from earlier issue 2801. + // Unexpected. But no logger plumbed here to log so. 
+ return + } + for ip := range pi.ipPorts { + delete(m.byIPPort, ip) + } +} diff --git a/vendor/tailscale.com/wgengine/magicsock/peermtu.go b/vendor/tailscale.com/wgengine/magicsock/peermtu.go new file mode 100644 index 0000000000..8013aa5ea0 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/peermtu.go @@ -0,0 +1,107 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (darwin && !ios) || (linux && !android) + +package magicsock + +// Peer path MTU routines shared by platforms that implement it. + +// DontFragSetting returns true if at least one of the underlying sockets of +// this connection is a UDP socket with the don't fragment bit set, otherwise it +// returns false. It also returns an error if either connection returned an error +// other than errUnsupportedConnType. +func (c *Conn) DontFragSetting() (bool, error) { + df4, err4 := c.getDontFragment("udp4") + df6, err6 := c.getDontFragment("udp6") + df := df4 || df6 + err := err4 + if err4 != nil && err4 != errUnsupportedConnType { + err = err6 + } + if err == errUnsupportedConnType { + err = nil + } + return df, err +} + +// ShouldPMTUD returns true if this client should try to enable peer MTU +// discovery, false otherwise. +func (c *Conn) ShouldPMTUD() bool { + if v, ok := debugEnablePMTUD().Get(); ok { + if debugPMTUD() { + c.logf("magicsock: peermtu: peer path MTU discovery set via envknob to %v", v) + } + return v + } + if c.controlKnobs != nil { + if v := c.controlKnobs.PeerMTUEnable.Load(); v { + if debugPMTUD() { + c.logf("magicsock: peermtu: peer path MTU discovery enabled by control") + } + return v + } + } + if debugPMTUD() { + c.logf("magicsock: peermtu: peer path MTU discovery set by default to false") + } + return false // Until we feel confident PMTUD is solid. +} + +// PeerMTUEnabled reports whether peer path MTU discovery is enabled. 
+func (c *Conn) PeerMTUEnabled() bool { + return c.peerMTUEnabled.Load() +} + +// UpdatePMTUD configures the underlying sockets of this Conn to enable or disable +// peer path MTU discovery according to the current configuration. +// +// Enabling or disabling peer path MTU discovery requires setting the don't +// fragment bit on its two underlying pconns. There are three distinct results +// for this operation on each pconn: +// +// 1. Success +// 2. Failure (not supported on this platform, or supported but failed) +// 3. Not a UDP socket (most likely one of IPv4 or IPv6 couldn't be used) +// +// To simplify the fast path for the most common case, we set the PMTUD status +// of the overall Conn according to the results of setting the sockopt on pconn +// as follows: +// +// 1. Both setsockopts succeed: PMTUD status update succeeds +// 2. One succeeds, one returns not a UDP socket: PMTUD status update succeeds +// 3. Neither setsockopt succeeds: PMTUD disabled +// 4. Either setsockopt fails: PMTUD disabled +// +// If the PMTUD settings changed, it resets the endpoint state so that it will +// re-probe path MTUs to this peer.
+func (c *Conn) UpdatePMTUD() { + if debugPMTUD() { + df4, err4 := c.getDontFragment("udp4") + df6, err6 := c.getDontFragment("udp6") + c.logf("magicsock: peermtu: peer MTU status %v DF bit status: v4: %v (%v) v6: %v (%v)", c.peerMTUEnabled.Load(), df4, err4, df6, err6) + } + + enable := c.ShouldPMTUD() + if c.peerMTUEnabled.Load() == enable { + c.logf("[v1] magicsock: peermtu: peer MTU status is %v", enable) + return + } + + newStatus := enable + err4 := c.setDontFragment("udp4", enable) + err6 := c.setDontFragment("udp6", enable) + anySuccess := err4 == nil || err6 == nil + noFailures := (err4 == nil || err4 == errUnsupportedConnType) && (err6 == nil || err6 == errUnsupportedConnType) + + if anySuccess && noFailures { + c.logf("magicsock: peermtu: peer MTU status updated to %v", newStatus) + } else { + c.logf("[unexpected] magicsock: peermtu: updating peer MTU status to %v failed (v4: %v, v6: %v), disabling", enable, err4, err6) + _ = c.setDontFragment("udp4", false) + _ = c.setDontFragment("udp6", false) + newStatus = false + } + c.peerMTUEnabled.Store(newStatus) + c.resetEndpointStates() +} diff --git a/vendor/tailscale.com/wgengine/magicsock/peermtu_darwin.go b/vendor/tailscale.com/wgengine/magicsock/peermtu_darwin.go new file mode 100644 index 0000000000..a0a1aacb55 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/peermtu_darwin.go @@ -0,0 +1,51 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package magicsock + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func getDontFragOpt(network string) int { + if network == "udp4" { + return unix.IP_DONTFRAG + } + return unix.IPV6_DONTFRAG +} + +func (c *Conn) setDontFragment(network string, enable bool) error { + optArg := 1 + if enable == false { + optArg = 0 + } + var err error + rcErr := c.connControl(network, func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) + }) 
+ + if rcErr != nil { + return rcErr + } + return err +} + +func (c *Conn) getDontFragment(network string) (bool, error) { + var v int + var err error + rcErr := c.connControl(network, func(fd uintptr) { + v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) + }) + + if rcErr != nil { + return false, rcErr + } + if v == 1 { + return true, err + } + return false, err +} diff --git a/vendor/tailscale.com/wgengine/magicsock/peermtu_linux.go b/vendor/tailscale.com/wgengine/magicsock/peermtu_linux.go new file mode 100644 index 0000000000..b76f30f081 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/peermtu_linux.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android + +package magicsock + +import ( + "syscall" +) + +func getDontFragOpt(network string) int { + if network == "udp4" { + return syscall.IP_MTU_DISCOVER + } + return syscall.IPV6_MTU_DISCOVER +} + +func (c *Conn) setDontFragment(network string, enable bool) error { + optArg := syscall.IP_PMTUDISC_DO + if enable == false { + optArg = syscall.IP_PMTUDISC_DONT + } + var err error + rcErr := c.connControl(network, func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) + }) + + if rcErr != nil { + return rcErr + } + return err +} + +func (c *Conn) getDontFragment(network string) (bool, error) { + var v int + var err error + rcErr := c.connControl(network, func(fd uintptr) { + v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) + }) + + if rcErr != nil { + return false, rcErr + } + if v == syscall.IP_PMTUDISC_DO { + return true, err + } + return false, err +} diff --git a/vendor/tailscale.com/wgengine/magicsock/peermtu_stubs.go b/vendor/tailscale.com/wgengine/magicsock/peermtu_stubs.go new file mode 100644 index 0000000000..6981f28c35 --- /dev/null +++ 
b/vendor/tailscale.com/wgengine/magicsock/peermtu_stubs.go @@ -0,0 +1,46 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (!linux && !darwin) || android || ios + +package magicsock + +import ( + "errors" +) + +// setDontFragment sets the don't fragment sockopt on the underlying connection +// specified by network, which must be "udp4" or "udp6". See +// https://datatracker.ietf.org/doc/html/rfc3542#section-11.2 for details on +// IPv6 fragmentation. +// +// Return values: +// - an error if peer MTU is not supported on this OS +// - errNoActiveUDP if the underlying connection is not UDP +// - otherwise, the result of setting the don't fragment bit +func (c *Conn) setDontFragment(network string, enable bool) error { + return errors.New("peer path MTU discovery not supported on this OS") +} + +// getDontFragment gets the don't fragment setting on the underlying connection +// specified by network, which must be "udp4" or "udp6". Returns true if the +// underlying connection is UDP and the don't fragment bit is set, otherwise +// false. 
+func (c *Conn) getDontFragment(network string) (bool, error) { + return false, nil +} + +func (c *Conn) DontFragSetting() (bool, error) { + return false, nil +} + +func (c *Conn) ShouldPMTUD() bool { + return false +} + +func (c *Conn) PeerMTUEnabled() bool { + return false +} + +func (c *Conn) UpdatePMTUD() { +} diff --git a/vendor/tailscale.com/wgengine/magicsock/peermtu_unix.go b/vendor/tailscale.com/wgengine/magicsock/peermtu_unix.go new file mode 100644 index 0000000000..eec3d744f3 --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/peermtu_unix.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (darwin && !ios) || (linux && !android) + +package magicsock + +import ( + "syscall" +) + +// getIPProto returns the value of the get/setsockopt proto argument necessary +// to set an IP sockopt that corresponds with the string network, which must be +// "udp4" or "udp6". +func getIPProto(network string) int { + if network == "udp4" { + return syscall.IPPROTO_IP + } + return syscall.IPPROTO_IPV6 +} + +// connControl allows the caller to run a system call on the socket underlying +// Conn specified by the string network, which must be "udp4" or "udp6". If the +// pconn type implements the syscall method, this function returns the value +// of the system call fn called with the fd of the socket as its arg (or the +// error from rc.Control() if that fails). Otherwise it returns the error +// errUnsupportedConnType.
+func (c *Conn) connControl(network string, fn func(fd uintptr)) error { + pconn := c.pconn4.pconn + if network == "udp6" { + pconn = c.pconn6.pconn + } + sc, ok := pconn.(syscall.Conn) + if !ok { + return errUnsupportedConnType + } + rc, err := sc.SyscallConn() + if err != nil { + return err + } + return rc.Control(fn) +} diff --git a/vendor/tailscale.com/wgengine/magicsock/rebinding_conn.go b/vendor/tailscale.com/wgengine/magicsock/rebinding_conn.go new file mode 100644 index 0000000000..f1e47f3a8b --- /dev/null +++ b/vendor/tailscale.com/wgengine/magicsock/rebinding_conn.go @@ -0,0 +1,179 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "errors" + "net" + "net/netip" + "sync" + "sync/atomic" + "syscall" + + "golang.org/x/net/ipv6" + "tailscale.com/net/netaddr" + "tailscale.com/types/nettype" +) + +// RebindingUDPConn is a UDP socket that can be re-bound. +// Unix has no notion of re-binding a socket, so we swap it out for a new one. +type RebindingUDPConn struct { + // pconnAtomic is a pointer to the value stored in pconn, but doesn't + // require acquiring mu. It's used for reads/writes and only upon failure + // do the reads/writes then check pconn (after acquiring mu) to see if + // there's been a rebind meanwhile. + // pconn isn't really needed, but makes some of the code simpler + // to keep it distinct. + // Neither is expected to be nil, sockets are bound on creation. + pconnAtomic atomic.Pointer[nettype.PacketConn] + + mu sync.Mutex // held while changing pconn (and pconnAtomic) + pconn nettype.PacketConn + port uint16 +} + +// setConnLocked sets the provided nettype.PacketConn. It should be called only +// after acquiring RebindingUDPConn.mu. It upgrades the provided +// nettype.PacketConn to a *batchingUDPConn when appropriate. 
This upgrade +// is intentionally pushed closest to where read/write ops occur in order to +// avoid disrupting surrounding code that assumes nettype.PacketConn is a +// *net.UDPConn. +func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) { + upc := tryUpgradeToBatchingUDPConn(p, network, batchSize) + c.pconn = upc + c.pconnAtomic.Store(&upc) + c.port = uint16(c.localAddrLocked().Port) +} + +// currentConn returns c's current pconn, acquiring c.mu in the process. +func (c *RebindingUDPConn) currentConn() nettype.PacketConn { + c.mu.Lock() + defer c.mu.Unlock() + return c.pconn +} + +func (c *RebindingUDPConn) readFromWithInitPconn(pconn nettype.PacketConn, b []byte) (int, netip.AddrPort, error) { + for { + n, addr, err := pconn.ReadFromUDPAddrPort(b) + if err != nil && pconn != c.currentConn() { + pconn = *c.pconnAtomic.Load() + continue + } + return n, addr, err + } +} + +// ReadFromUDPAddrPort reads a packet from c into b. +// It returns the number of bytes copied and the source address. +func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) { + return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b) +} + +// WriteBatchTo writes buffs to addr. +func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { + for { + pconn := *c.pconnAtomic.Load() + b, ok := pconn.(*batchingUDPConn) + if !ok { + for _, buf := range buffs { + _, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr) + if err != nil { + return err + } + } + return nil + } + err := b.WriteBatchTo(buffs, addr) + if err != nil { + if pconn != c.currentConn() { + continue + } + return err + } + return err + } +} + +// ReadBatch reads messages from c into msgs. It returns the number of messages +// the caller should evaluate for nonzero len, as a zero len message may fall +// on either side of a nonzero. 
+func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) { + for { + pconn := *c.pconnAtomic.Load() + b, ok := pconn.(*batchingUDPConn) + if !ok { + n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0]) + if err == nil { + msgs[0].N = n + msgs[0].Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(ap)) + return 1, nil + } + return 0, err + } + n, err := b.ReadBatch(msgs, flags) + if err != nil && pconn != c.currentConn() { + continue + } + return n, err + } +} + +func (c *RebindingUDPConn) Port() uint16 { + c.mu.Lock() + defer c.mu.Unlock() + return c.port +} + +func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr { + c.mu.Lock() + defer c.mu.Unlock() + return c.localAddrLocked() +} + +func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr { + return c.pconn.LocalAddr().(*net.UDPAddr) +} + +// errNilPConn is returned by RebindingUDPConn.Close when there is no current pconn. +// It is for internal use only and should not be returned to users. +var errNilPConn = errors.New("nil pconn") + +func (c *RebindingUDPConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.closeLocked() +} + +func (c *RebindingUDPConn) closeLocked() error { + if c.pconn == nil { + return errNilPConn + } + c.port = 0 + return c.pconn.Close() +} + +func (c *RebindingUDPConn) writeToUDPAddrPortWithInitPconn(pconn nettype.PacketConn, b []byte, addr netip.AddrPort) (int, error) { + for { + n, err := pconn.WriteToUDPAddrPort(b, addr) + if err != nil && pconn != c.currentConn() { + pconn = *c.pconnAtomic.Load() + continue + } + return n, err + } +} + +func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + return c.writeToUDPAddrPortWithInitPconn(*c.pconnAtomic.Load(), b, addr) +} + +func (c *RebindingUDPConn) SyscallConn() (syscall.RawConn, error) { + c.mu.Lock() + defer c.mu.Unlock() + sc, ok := c.pconn.(syscall.Conn) + if !ok { + return nil, errUnsupportedConnType + } + return sc.SyscallConn() +} diff --git 
a/vendor/tailscale.com/wgengine/netlog/logger.go b/vendor/tailscale.com/wgengine/netlog/logger.go index a694308e60..3dd02afb96 100644 --- a/vendor/tailscale.com/wgengine/netlog/logger.go +++ b/vendor/tailscale.com/wgengine/netlog/logger.go @@ -101,7 +101,8 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo } // Startup a log stream to Tailscale's logging service. - httpc := &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, netMon)} + logf := log.Printf + httpc := &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, netMon, logf)} if testClient != nil { httpc = testClient } @@ -123,7 +124,7 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo // Include process sequence numbers to identify missing samples. IncludeProcID: true, IncludeProcSequence: true, - }, log.Printf) + }, logf) nl.logger.SetSockstatsLabel(sockstats.LabelNetlogLogger) // Startup a data structure to track per-connection statistics. diff --git a/vendor/tailscale.com/wgengine/netstack/netstack.go b/vendor/tailscale.com/wgengine/netstack/netstack.go index 54e9918014..e2f930b4c2 100644 --- a/vendor/tailscale.com/wgengine/netstack/netstack.go +++ b/vendor/tailscale.com/wgengine/netstack/netstack.go @@ -43,7 +43,9 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/net/tstun" + "tailscale.com/proxymap" "tailscale.com/syncs" + "tailscale.com/tailcfg" "tailscale.com/types/ipproto" "tailscale.com/types/logger" "tailscale.com/types/netmap" @@ -120,6 +122,7 @@ type Impl struct { linkEP *channel.Endpoint tundev *tstun.Wrapper e wgengine.Engine + pm *proxymap.Mapper mc *magicsock.Conn logf logger.Logf dialer *tsdial.Dialer @@ -153,7 +156,7 @@ const nicID = 1 const maxUDPPacketSize = 1500 // Create creates and populates a new Impl. 
-func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager) (*Impl, error) { +func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper) (*Impl, error) { if mc == nil { return nil, errors.New("nil magicsock.Conn") } @@ -166,6 +169,9 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi if e == nil { return nil, errors.New("nil Engine") } + if pm == nil { + return nil, errors.New("nil proxymap.Mapper") + } if dialer == nil { return nil, errors.New("nil Dialer") } @@ -208,13 +214,14 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi linkEP: linkEP, tundev: tundev, e: e, + pm: pm, mc: mc, dialer: dialer, connsOpenBySubnetIP: make(map[netip.Addr]int), dns: dns, } ns.ctx, ns.ctxCancel = context.WithCancel(context.Background()) - ns.atomicIsLocalIPFunc.Store(tsaddr.NewContainsIPFunc(nil)) + ns.atomicIsLocalIPFunc.Store(tsaddr.FalseContainsIPFunc()) ns.tundev.PostFilterPacketInboundFromWireGaurd = ns.injectInbound ns.tundev.PreFilterPacketOutboundToWireGuardNetstackIntercept = ns.handleLocalPackets return ns, nil @@ -253,7 +260,6 @@ func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error { panic("nil LocalBackend") } ns.lb = lb - ns.e.AddNetworkMapCallback(ns.updateIPs) // size = 0 means use default buffer size const tcpReceiveBufferSize = 0 const maxInFlightConnectionAttempts = 1024 @@ -310,8 +316,19 @@ func ipPrefixToAddressWithPrefix(ipp netip.Prefix) tcpip.AddressWithPrefix { var v4broadcast = netaddr.IPv4(255, 255, 255, 255) -func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { - ns.atomicIsLocalIPFunc.Store(tsaddr.NewContainsIPFunc(nm.Addresses)) +// UpdateNetstackIPs updates the set of local IPs that netstack should handle +// from nm. +// +// TODO(bradfitz): don't pass the whole netmap here; just pass the two +// address slice views. 
+func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) { + var selfNode tailcfg.NodeView + if nm != nil { + ns.atomicIsLocalIPFunc.Store(tsaddr.NewContainsIPFunc(nm.GetAddresses())) + selfNode = nm.SelfNode + } else { + ns.atomicIsLocalIPFunc.Store(tsaddr.FalseContainsIPFunc()) + } oldIPs := make(map[tcpip.AddressWithPrefix]bool) for _, protocolAddr := range ns.ipstack.AllAddresses()[nicID] { @@ -328,12 +345,14 @@ func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { newIPs := make(map[tcpip.AddressWithPrefix]bool) isAddr := map[netip.Prefix]bool{} - if nm.SelfNode != nil { - for _, ipp := range nm.SelfNode.Addresses { + if selfNode.Valid() { + for i := range selfNode.Addresses().LenIter() { + ipp := selfNode.Addresses().At(i) isAddr[ipp] = true newIPs[ipPrefixToAddressWithPrefix(ipp)] = true } - for _, ipp := range nm.SelfNode.AllowedIPs { + for i := range selfNode.AllowedIPs().LenIter() { + ipp := selfNode.AllowedIPs().At(i) if !isAddr[ipp] && ns.ProcessSubnets { newIPs[ipPrefixToAddressWithPrefix(ipp)] = true } @@ -971,8 +990,8 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. 
backendLocalAddr := server.LocalAddr().(*net.TCPAddr) backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort()) - ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) - defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort) + ns.pm.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) + defer ns.pm.UnregisterIPPortIdentity(backendLocalIPPort) connClosed := make(chan error, 2) go func() { _, err := io.Copy(server, client) @@ -1122,7 +1141,7 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port) } if isLocal { - ns.e.RegisterIPPortIdentity(backendLocalIPPort, dstAddr.Addr()) + ns.pm.RegisterIPPortIdentity(backendLocalIPPort, dstAddr.Addr()) } ctx, cancel := context.WithCancel(context.Background()) @@ -1138,7 +1157,7 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr } timer := time.AfterFunc(idleTimeout, func() { if isLocal { - ns.e.UnregisterIPPortIdentity(backendLocalIPPort) + ns.pm.UnregisterIPPortIdentity(backendLocalIPPort) } ns.logf("netstack: UDP session between %s and %s timed out", backendListenAddr, backendRemoteAddr) cancel() diff --git a/vendor/tailscale.com/wgengine/pendopen.go b/vendor/tailscale.com/wgengine/pendopen.go index f21ef75ec5..af5d3d8a7d 100644 --- a/vendor/tailscale.com/wgengine/pendopen.go +++ b/vendor/tailscale.com/wgengine/pendopen.go @@ -146,19 +146,22 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { return } n := pip.Node - if n.DiscoKey.IsZero() { - e.logf("open-conn-track: timeout opening %v; peer node %v running pre-0.100", flow, n.Key.ShortString()) - return - } - if n.DERP == "" { - e.logf("open-conn-track: timeout opening %v; peer node %v not connected to any DERP relay", flow, n.Key.ShortString()) - return + if !n.IsWireGuardOnly() { + if n.DiscoKey().IsZero() { + e.logf("open-conn-track: timeout opening %v; peer node %v 
running pre-0.100", flow, n.Key().ShortString()) + return + } + if n.DERP() == "" { + e.logf("open-conn-track: timeout opening %v; peer node %v not connected to any DERP relay", flow, n.Key().ShortString()) + return + } } - ps, found := e.getPeerStatusLite(n.Key) + ps, found := e.getPeerStatusLite(n.Key()) if !found { onlyZeroRoute := true // whether peerForIP returned n only because its /0 route matched - for _, r := range n.AllowedIPs { + for i := range n.AllowedIPs().LenIter() { + r := n.AllowedIPs().At(i) if r.Bits() != 0 && r.Contains(flow.Dst.Addr()) { onlyZeroRoute = false break @@ -176,7 +179,7 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { // node. return } - e.logf("open-conn-track: timeout opening %v; target node %v in netmap but unknown to WireGuard", flow, n.Key.ShortString()) + e.logf("open-conn-track: timeout opening %v; target node %v in netmap but unknown to WireGuard", flow, n.Key().ShortString()) return } @@ -187,20 +190,24 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { _ = ps.LastHandshake online := "?" 
- if n.Online != nil { - if *n.Online { - online = "yes" - } else { - online = "no" + if n.IsWireGuardOnly() { + online = "wg" + } else { + if v := n.Online(); v != nil { + if *v { + online = "yes" + } else { + online = "no" + } + } + if n.LastSeen() != nil && online != "yes" { + online += fmt.Sprintf(", lastseen=%v", durFmt(*n.LastSeen())) } - } - if n.LastSeen != nil && online != "yes" { - online += fmt.Sprintf(", lastseen=%v", durFmt(*n.LastSeen)) } e.logf("open-conn-track: timeout opening %v to node %v; online=%v, lastRecv=%v", - flow, n.Key.ShortString(), + flow, n.Key().ShortString(), online, - e.magicConn.LastRecvActivityOfNodeKey(n.Key)) + e.magicConn.LastRecvActivityOfNodeKey(n.Key())) } func durFmt(t time.Time) string { diff --git a/vendor/tailscale.com/wgengine/router/callback.go b/vendor/tailscale.com/wgengine/router/callback.go index 1d7f328467..c1838539ba 100644 --- a/vendor/tailscale.com/wgengine/router/callback.go +++ b/vendor/tailscale.com/wgengine/router/callback.go @@ -24,9 +24,16 @@ type CallbackRouter struct { // will return ErrGetBaseConfigNotSupported. GetBaseConfigFunc func() (dns.OSConfig, error) - mu sync.Mutex // protects all the following - rcfg *Config // last applied router config - dcfg *dns.OSConfig // last applied DNS config + // InitialMTU is the MTU the tun should be initialized with. + // Zero means don't change the MTU from the default. This MTU + // is applied only once, shortly after the TUN is created, and + // ignored thereafter. + InitialMTU uint32 + + mu sync.Mutex // protects all the following + didSetMTU bool // if we set the MTU already + rcfg *Config // last applied router config + dcfg *dns.OSConfig // last applied DNS config } // Up implements Router. 
@@ -41,6 +48,10 @@ func (r *CallbackRouter) Set(rcfg *Config) error { if r.rcfg.Equal(rcfg) { return nil } + if r.didSetMTU == false { + r.didSetMTU = true + rcfg.NewMTU = int(r.InitialMTU) + } r.rcfg = rcfg return r.SetBoth(r.rcfg, r.dcfg) } diff --git a/vendor/tailscale.com/wgengine/router/ifconfig_windows.go b/vendor/tailscale.com/wgengine/router/ifconfig_windows.go index 1cd01eee14..f6bb21c92c 100644 --- a/vendor/tailscale.com/wgengine/router/ifconfig_windows.go +++ b/vendor/tailscale.com/wgengine/router/ifconfig_windows.go @@ -11,13 +11,13 @@ import ( "log" "net" "net/netip" + "slices" "sort" "time" ole "github.com/go-ole/go-ole" "github.com/tailscale/wireguard-go/tun" "go4.org/netipx" - "golang.org/x/exp/slices" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/health" @@ -396,7 +396,7 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) (retErr error) { return fmt.Errorf("syncAddresses: %w", err) } - slices.SortFunc(routes, routeDataLess) + slices.SortFunc(routes, routeDataCompare) deduplicatedRoutes := []*winipcfg.RouteData{} for i := 0; i < len(routes); i++ { @@ -652,8 +652,8 @@ func routeDataCompare(a, b *winipcfg.RouteData) int { func deltaRouteData(a, b []*winipcfg.RouteData) (add, del []*winipcfg.RouteData) { add = make([]*winipcfg.RouteData, 0, len(b)) del = make([]*winipcfg.RouteData, 0, len(a)) - slices.SortFunc(a, routeDataLess) - slices.SortFunc(b, routeDataLess) + slices.SortFunc(a, routeDataCompare) + slices.SortFunc(b, routeDataCompare) i := 0 j := 0 diff --git a/vendor/tailscale.com/wgengine/router/router.go b/vendor/tailscale.com/wgengine/router/router.go index 11668a70e9..ecea521ad5 100644 --- a/vendor/tailscale.com/wgengine/router/router.go +++ b/vendor/tailscale.com/wgengine/router/router.go @@ -67,6 +67,11 @@ type Config struct { // routing rules apply. 
LocalRoutes []netip.Prefix + // NewMTU is currently only used by the MacOS network extension + // app to set the MTU of the tun in the router configuration + // callback. If zero, the MTU is unchanged. + NewMTU int + // Linux-only things below, ignored on other platforms. SubnetRoutes []netip.Prefix // subnets being advertised to other Tailscale nodes SNATSubnetRoutes bool // SNAT traffic to local subnets diff --git a/vendor/tailscale.com/wgengine/router/router_linux.go b/vendor/tailscale.com/wgengine/router/router_linux.go index 85c3799c42..8a7273bd22 100644 --- a/vendor/tailscale.com/wgengine/router/router_linux.go +++ b/vendor/tailscale.com/wgengine/router/router_linux.go @@ -4,7 +4,6 @@ package router import ( - "bytes" "errors" "fmt" "net" @@ -17,17 +16,17 @@ import ( "syscall" "time" - "github.com/coreos/go-iptables/iptables" "github.com/tailscale/netlink" "github.com/tailscale/wireguard-go/tun" "go4.org/netipx" "golang.org/x/sys/unix" "golang.org/x/time/rate" "tailscale.com/envknob" + "tailscale.com/hostinfo" "tailscale.com/net/netmon" - "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/types/preftype" + "tailscale.com/util/linuxfw" "tailscale.com/util/multierr" "tailscale.com/version/distro" ) @@ -38,56 +37,143 @@ const ( netfilterOn = preftype.NetfilterOn ) -// The following bits are added to packet marks for Tailscale use. -// -// We tried to pick bits sufficiently out of the way that it's -// unlikely to collide with existing uses. We have 4 bytes of mark -// bits to play with. We leave the lower byte alone on the assumption -// that sysadmins would use those. Kubernetes uses a few bits in the -// second byte, so we steer clear of that too. -// -// Empirically, most of the documentation on packet marks on the -// internet gives the impression that the marks are 16 bits -// wide. Based on this, we theorize that the upper two bytes are -// relatively unused in the wild, and so we consume bits 16:23 (the -// third byte). 
-// -// The constants are in the iptables/iproute2 string format for -// matching and setting the bits, so they can be directly embedded in -// commands. -const ( - // The mask for reading/writing the 'firewall mask' bits on a packet. - // See the comment on the const block on why we only use the third byte. - // - // We claim bits 16:23 entirely. For now we only use the lower four - // bits, leaving the higher 4 bits for future use. - tailscaleFwmarkMask = "0xff0000" - tailscaleFwmarkMaskNum = 0xff0000 +// netfilterRunner abstracts helpers to run netfilter commands. It is +// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner. +type netfilterRunner interface { + AddLoopbackRule(addr netip.Addr) error + DelLoopbackRule(addr netip.Addr) error + AddHooks() error + DelHooks(logf logger.Logf) error + AddChains() error + DelChains() error + AddBase(tunname string) error + DelBase() error + AddSNATRule() error + DelSNATRule() error - // Packet is from Tailscale and to a subnet route destination, so - // is allowed to be routed through this machine. - tailscaleSubnetRouteMark = "0x40000" + HasIPV6() bool + HasIPV6NAT() bool +} - // Packet was originated by tailscaled itself, and must not be - // routed over the Tailscale network. - // - // Keep this in sync with tailscaleBypassMark in - // net/netns/netns_linux.go. - tailscaleBypassMark = "0x80000" - tailscaleBypassMarkNum = 0x80000 -) +// tableDetector abstracts helpers to detect the firewall mode. +// It is implemented for testing purposes. +type tableDetector interface { + iptDetect() (int, error) + nftDetect() (int, error) +} -// netfilterRunner abstracts helpers to run netfilter commands. It -// exists purely to swap out go-iptables for a fake implementation in -// tests. 
-type netfilterRunner interface { - Insert(table, chain string, pos int, args ...string) error - Append(table, chain string, args ...string) error - Exists(table, chain string, args ...string) (bool, error) - Delete(table, chain string, args ...string) error - ClearChain(table, chain string) error - NewChain(table, chain string) error - DeleteChain(table, chain string) error +type linuxFWDetector struct{} + +// iptDetect returns the number of iptables rules in the current namespace. +func (l *linuxFWDetector) iptDetect() (int, error) { + return linuxfw.DetectIptables() +} + +// nftDetect returns the number of nftables rules in the current namespace. +func (l *linuxFWDetector) nftDetect() (int, error) { + return linuxfw.DetectNetfilter() +} + +// chooseFireWallMode returns the firewall mode to use based on the +// environment and the system's capabilities. +func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode { + if distro.Get() == distro.Gokrazy { + // Reduce startup logging on gokrazy. There's no way to do iptables on + // gokrazy anyway. + return linuxfw.FirewallModeNfTables + } + iptAva, nftAva := true, true + iptRuleCount, err := det.iptDetect() + if err != nil { + logf("detect iptables rule: %v", err) + iptAva = false + } + nftRuleCount, err := det.nftDetect() + if err != nil { + logf("detect nftables rule: %v", err) + nftAva = false + } + logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount) + switch { + case nftRuleCount > 0 && iptRuleCount == 0: + logf("nftables is currently in use") + hostinfo.SetFirewallMode("nft-inuse") + return linuxfw.FirewallModeNfTables + case iptRuleCount > 0 && nftRuleCount == 0: + logf("iptables is currently in use") + hostinfo.SetFirewallMode("ipt-inuse") + return linuxfw.FirewallModeIPTables + case nftAva: + // if both iptables and nftables are available but + // neither/both are currently used, use nftables. 
+ logf("nftables is available") + hostinfo.SetFirewallMode("nft") + return linuxfw.FirewallModeNfTables + case iptAva: + logf("iptables is available") + hostinfo.SetFirewallMode("ipt") + return linuxfw.FirewallModeIPTables + default: + // if neither iptables nor nftables are available, use iptablesRunner as a dummy + // runner which exists but won't do anything. Creating iptablesRunner errors only + // if the iptables command is missing or doesn’t support "--version", as long as it + // can determine a version then it’ll carry on. + hostinfo.SetFirewallMode("ipt-fb") + return linuxfw.FirewallModeIPTables + } +} + +// newNetfilterRunner creates a netfilterRunner using either nftables or iptables. +// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set. +func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) { + tableDetector := &linuxFWDetector{} + var mode linuxfw.FirewallMode + + // We now use iptables as default and have "auto" and "nftables" as + // options for people to test further. + switch { + case distro.Get() == distro.Gokrazy: + // Reduce startup logging on gokrazy. There's no way to do iptables on + // gokrazy anyway. 
+ logf("GoKrazy should use nftables.") + hostinfo.SetFirewallMode("nft-gokrazy") + mode = linuxfw.FirewallModeNfTables + case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables": + logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set") + hostinfo.SetFirewallMode("nft-forced") + mode = linuxfw.FirewallModeNfTables + case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto": + mode = chooseFireWallMode(logf, tableDetector) + case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables": + logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set") + hostinfo.SetFirewallMode("ipt-forced") + mode = linuxfw.FirewallModeIPTables + default: + logf("default choosing iptables") + hostinfo.SetFirewallMode("ipt-default") + mode = linuxfw.FirewallModeIPTables + } + + var nfr netfilterRunner + var err error + switch mode { + case linuxfw.FirewallModeIPTables: + logf("using iptables") + nfr, err = linuxfw.NewIPTablesRunner(logf) + if err != nil { + return nil, err + } + case linuxfw.FirewallModeNfTables: + logf("using nftables") + nfr, err = linuxfw.NewNfTablesRunner(logf) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unknown firewall mode: %v", mode) + } + + return nfr, nil } type linuxRouter struct { @@ -109,16 +195,13 @@ type linuxRouter struct { // Various feature checks for the network stack. ipRuleAvailable bool // whether kernel was built with IP_MULTIPLE_TABLES - v6Available bool - v6NATAvailable bool fwmaskWorks bool // whether we can use 'ip rule...fwmark /' // ipPolicyPrefBase is the base priority at which ip rules are installed. 
ipPolicyPrefBase int - ipt4 netfilterRunner - ipt6 netfilterRunner - cmd commandRunner + nfr netfilterRunner + cmd commandRunner } func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor) (Router, error) { @@ -127,51 +210,27 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni return nil, err } - ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + nfr, err := newNetfilterRunner(logf) if err != nil { return nil, err } - v6err := checkIPv6(logf) - if v6err != nil { - logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) - } - supportsV6 := v6err == nil - supportsV6NAT := supportsV6 && supportsV6NAT() - if supportsV6 { - logf("v6nat = %v", supportsV6NAT) - } - - var ipt6 netfilterRunner - if supportsV6 { - // The iptables package probes for `ip6tables` and errors out - // if unavailable. We want that to be a non-fatal error. - ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - return nil, err - } - } - cmd := osCommandRunner{ ambientCapNetAdmin: useAmbientCaps(), } - return newUserspaceRouterAdvanced(logf, tunname, netMon, ipt4, ipt6, cmd, supportsV6, supportsV6NAT) + return newUserspaceRouterAdvanced(logf, tunname, netMon, nfr, cmd) } -func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, netfilter4, netfilter6 netfilterRunner, cmd commandRunner, supportsV6, supportsV6NAT bool) (Router, error) { +func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr netfilterRunner, cmd commandRunner) (Router, error) { r := &linuxRouter{ logf: logf, tunname: tunname, netfilterMode: netfilterOff, netMon: netMon, - v6Available: supportsV6, - v6NATAvailable: supportsV6NAT, - - ipt4: netfilter4, - ipt6: netfilter6, - cmd: cmd, + nfr: nfr, + cmd: cmd, ipRuleFixLimiter: rate.NewLimiter(rate.Every(5*time.Second), 10), ipPolicyPrefBase: 5200, @@ -484,23 +543,23 @@ func (r *linuxRouter) 
setNetfilterMode(mode preftype.NetfilterMode) error { case netfilterOff: switch r.netfilterMode { case netfilterNoDivert: - if err := r.delNetfilterBase(); err != nil { + if err := r.nfr.DelBase(); err != nil { return err } - if err := r.delNetfilterChains(); err != nil { + if err := r.nfr.DelChains(); err != nil { r.logf("note: %v", err) // harmless, continue. // This can happen if someone left a ref to // this table somewhere else. } case netfilterOn: - if err := r.delNetfilterHooks(); err != nil { + if err := r.nfr.DelHooks(r.logf); err != nil { return err } - if err := r.delNetfilterBase(); err != nil { + if err := r.nfr.DelBase(); err != nil { return err } - if err := r.delNetfilterChains(); err != nil { + if err := r.nfr.DelChains(); err != nil { r.logf("note: %v", err) // harmless, continue. // This can happen if someone left a ref to @@ -512,15 +571,15 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { switch r.netfilterMode { case netfilterOff: reprocess = true - if err := r.addNetfilterChains(); err != nil { + if err := r.nfr.AddChains(); err != nil { return err } - if err := r.addNetfilterBase(); err != nil { + if err := r.nfr.AddBase(r.tunname); err != nil { return err } r.snatSubnetRoutes = false case netfilterOn: - if err := r.delNetfilterHooks(); err != nil { + if err := r.nfr.DelHooks(r.logf); err != nil { return err } } @@ -529,33 +588,35 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { // we can't add a "-j ts-forward" rule to FORWARD // while ts-forward contains an "-m mark" rule. But // we can add the row *before* populating ts-forward. - // So we have to delNetFilterBase, then add the hooks, - // then re-addNetFilterBase, just in case. + // So we have to delBase, then add the hooks, + // then re-addBase, just in case. 
switch r.netfilterMode { case netfilterOff: reprocess = true - if err := r.addNetfilterChains(); err != nil { + if err := r.nfr.AddChains(); err != nil { return err } - if err := r.delNetfilterBase(); err != nil { + if err := r.nfr.DelBase(); err != nil { return err } - if err := r.addNetfilterHooks(); err != nil { + // AddHooks adds the ts loopback rule. + if err := r.nfr.AddHooks(); err != nil { return err } - if err := r.addNetfilterBase(); err != nil { + // AddBase adds base ts rules + if err := r.nfr.AddBase(r.tunname); err != nil { return err } r.snatSubnetRoutes = false case netfilterNoDivert: reprocess = true - if err := r.delNetfilterBase(); err != nil { + if err := r.nfr.DelBase(); err != nil { return err } - if err := r.addNetfilterHooks(); err != nil { + if err := r.nfr.AddHooks(); err != nil { return err } - if err := r.addNetfilterBase(); err != nil { + if err := r.nfr.AddBase(r.tunname); err != nil { return err } r.snatSubnetRoutes = false @@ -579,11 +640,19 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { return nil } +func (r *linuxRouter) getV6Available() bool { + return r.nfr.HasIPV6() +} + +func (r *linuxRouter) getV6NATAvailable() bool { + return r.nfr.HasIPV6NAT() +} + // addAddress adds an IP/mask to the tunnel interface. Fails if the // address is already assigned to the interface, or if the addition // fails. func (r *linuxRouter) addAddress(addr netip.Prefix) error { - if !r.v6Available && addr.Addr().Is6() { + if !r.getV6Available() && addr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -609,7 +678,7 @@ func (r *linuxRouter) addAddress(addr netip.Prefix) error { // the address is not assigned to the interface, or if the removal // fails. 
func (r *linuxRouter) delAddress(addr netip.Prefix) error { - if !r.v6Available && addr.Addr().Is6() { + if !r.getV6Available() && addr.Addr().Is6() { return nil } if err := r.delLoopbackRule(addr.Addr()); err != nil { @@ -638,17 +707,8 @@ func (r *linuxRouter) addLoopbackRule(addr netip.Addr) error { return nil } - nf := r.ipt4 - if addr.Is6() { - if !r.v6Available { - // IPv6 not available, ignore. - return nil - } - nf = r.ipt6 - } - - if err := nf.Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { - return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err) + if err := r.nfr.AddLoopbackRule(addr); err != nil { + return err } return nil } @@ -660,17 +720,8 @@ func (r *linuxRouter) delLoopbackRule(addr netip.Addr) error { return nil } - nf := r.ipt4 - if addr.Is6() { - if !r.v6Available { - // IPv6 not available, ignore. - return nil - } - nf = r.ipt6 - } - - if err := nf.Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { - return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err) + if err := r.nfr.DelLoopbackRule(addr); err != nil { + return err } return nil } @@ -679,7 +730,7 @@ func (r *linuxRouter) delLoopbackRule(addr netip.Addr) error { // interface. Fails if the route already exists, or if adding the // route fails. 
func (r *linuxRouter) addRoute(cidr netip.Prefix) error { - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -704,7 +755,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error { if !r.ipRuleAvailable { return nil } - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -712,7 +763,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error { } err := netlink.RouteReplace(&netlink.Route{ Dst: netipx.PrefixIPNet(cidr.Masked()), - Table: tailscaleRouteTable.num, + Table: tailscaleRouteTable.Num, Type: unix.RTN_THROW, }) if err != nil { @@ -722,7 +773,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error { } func (r *linuxRouter) addRouteDef(routeDef []string, cidr netip.Prefix) error { - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } args := append([]string{"ip", "route", "add"}, routeDef...) @@ -756,7 +807,7 @@ var ( // interface. Fails if the route doesn't exist, or if removing the // route fails. func (r *linuxRouter) delRoute(cidr netip.Prefix) error { - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -784,7 +835,7 @@ func (r *linuxRouter) delThrowRoute(cidr netip.Prefix) error { if !r.ipRuleAvailable { return nil } - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -803,7 +854,7 @@ func (r *linuxRouter) delThrowRoute(cidr netip.Prefix) error { } func (r *linuxRouter) delRouteDef(routeDef []string, cidr netip.Prefix) error { - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } args := append([]string{"ip", "route", "del"}, routeDef...) 
@@ -865,7 +916,7 @@ func (r *linuxRouter) linkIndex() (int, error) { // routeTable returns the route table to use. func (r *linuxRouter) routeTable() int { if r.ipRuleAvailable { - return tailscaleRouteTable.num + return tailscaleRouteTable.Num } return 0 } @@ -962,7 +1013,7 @@ func (f addrFamily) netlinkInt() int { } func (r *linuxRouter) addrFamilies() []addrFamily { - if r.v6Available { + if r.getV6Available() { return []addrFamily{v4, v6} } return []addrFamily{v4} @@ -985,30 +1036,34 @@ func (r *linuxRouter) addIPRules() error { return r.justAddIPRules() } -// routeTable is a Linux routing table: both its name and number. +// RouteTable is a Linux routing table: both its name and number. // See /etc/iproute2/rt_tables. -type routeTable struct { - name string - num int +type RouteTable struct { + Name string + Num int } -// ipCmdArg returns the string form of the table to pass to the "ip" command. -func (rt routeTable) ipCmdArg() string { - if rt.num >= 253 { - return rt.name +var routeTableByNumber = map[int]RouteTable{} + +// IpCmdArg returns the string form of the table to pass to the "ip" command. +func (rt RouteTable) ipCmdArg() string { + if rt.Num >= 253 { + return rt.Name } - return strconv.Itoa(rt.num) + return strconv.Itoa(rt.Num) } -var routeTableByNumber = map[int]routeTable{} - -func newRouteTable(name string, num int) routeTable { - rt := routeTable{name, num} +func newRouteTable(name string, num int) RouteTable { + rt := RouteTable{name, num} routeTableByNumber[num] = rt return rt } -func mustRouteTable(num int) routeTable { +// MustRouteTable returns the RouteTable with the given number key. +// It panics if the number is unknown because this result is a part +// of IP rule argument and we don't want to continue with an invalid +// argument with table no exist. 
+func mustRouteTable(num int) RouteTable { rt, ok := routeTableByNumber[num] if !ok { panic(fmt.Sprintf("unknown route table %v", num)) @@ -1059,22 +1114,22 @@ var ipRules = []netlink.Rule{ // main routing table. { Priority: 10, - Mark: tailscaleBypassMarkNum, - Table: mainRouteTable.num, + Mark: linuxfw.TailscaleBypassMarkNum, + Table: mainRouteTable.Num, }, // ...and then we try the 'default' table, for correctness, // even though it's been empty on every Linux system I've ever seen. { Priority: 30, - Mark: tailscaleBypassMarkNum, - Table: defaultRouteTable.num, + Mark: linuxfw.TailscaleBypassMarkNum, + Table: defaultRouteTable.Num, }, // If neither of those matched (no default route on this system?) // then packets from us should be aborted rather than falling through // to the tailscale routes, because that would create routing loops. { Priority: 50, - Mark: tailscaleBypassMarkNum, + Mark: linuxfw.TailscaleBypassMarkNum, Type: unix.RTN_UNREACHABLE, }, // If we get to this point, capture all packets and send them @@ -1084,7 +1139,7 @@ var ipRules = []netlink.Rule{ // beat non-VPN routes. { Priority: 70, - Table: tailscaleRouteTable.num, + Table: tailscaleRouteTable.Num, }, // If that didn't match, then non-fwmark packets fall through to the // usual rules (pref 32766 and 32767, ie. main and default). @@ -1105,7 +1160,7 @@ func (r *linuxRouter) justAddIPRules() error { // Note: r is a value type here; safe to mutate it. 
ru.Family = family.netlinkInt() if ru.Mark != 0 { - ru.Mask = tailscaleFwmarkMaskNum + ru.Mask = linuxfw.TailscaleFwmarkMaskNum } ru.Goto = -1 ru.SuppressIfgroup = -1 @@ -1138,7 +1193,7 @@ func (r *linuxRouter) addIPRulesWithIPCommand() error { } if rule.Mark != 0 { if r.fwmaskWorks { - args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, tailscaleFwmarkMask)) + args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, linuxfw.TailscaleFwmarkMask)) } else { args = append(args, "fwmark", fmt.Sprintf("0x%x", rule.Mark)) } @@ -1239,284 +1294,6 @@ func (r *linuxRouter) delIPRulesWithIPCommand() error { return rg.ErrAcc } -func (r *linuxRouter) netfilterFamilies() []netfilterRunner { - if r.v6Available { - return []netfilterRunner{r.ipt4, r.ipt6} - } - return []netfilterRunner{r.ipt4} -} - -// addNetfilterChains creates custom Tailscale chains in netfilter. -func (r *linuxRouter) addNetfilterChains() error { - create := func(ipt netfilterRunner, table, chain string) error { - err := ipt.ClearChain(table, chain) - if errCode(err) == 1 { - // nonexistent chain. let's create it! - return ipt.NewChain(table, chain) - } - if err != nil { - return fmt.Errorf("setting up %s/%s: %w", table, chain, err) - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := create(ipt, "filter", "ts-input"); err != nil { - return err - } - if err := create(ipt, "filter", "ts-forward"); err != nil { - return err - } - } - if err := create(r.ipt4, "nat", "ts-postrouting"); err != nil { - return err - } - if r.v6NATAvailable { - if err := create(r.ipt6, "nat", "ts-postrouting"); err != nil { - return err - } - } - return nil -} - -// addNetfilterBase adds some basic processing rules to be -// supplemented by later calls to other helpers. 
-func (r *linuxRouter) addNetfilterBase() error { - if err := r.addNetfilterBase4(); err != nil { - return err - } - if r.v6Available { - if err := r.addNetfilterBase6(); err != nil { - return err - } - } - return nil -} - -// addNetfilterBase4 adds some basic IPv4 processing rules to be -// supplemented by later calls to other helpers. -func (r *linuxRouter) addNetfilterBase4() error { - // Only allow CGNAT range traffic to come from tailscale0. There - // is an exception carved out for ranges used by ChromeOS, for - // which we fall out of the Tailscale chain. - // - // Note, this will definitely break nodes that end up using the - // CGNAT range for other purposes :(. - args := []string{"!", "-i", r.tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"} - if err := r.ipt4.Append("filter", "ts-input", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) - } - args = []string{"!", "-i", r.tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} - if err := r.ipt4.Append("filter", "ts-input", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) - } - - // Forward all traffic from the Tailscale interface, and drop - // traffic to the tailscale interface by default. We use packet - // marks here so both filter/FORWARD and nat/POSTROUTING can match - // on these packets of interest. - // - // In particular, we only want to apply SNAT rules in - // nat/POSTROUTING to packets that originated from the Tailscale - // interface, but we can't match on the inbound interface in - // POSTROUTING. So instead, we match on the inbound interface in - // filter/FORWARD, and set a packet mark that nat/POSTROUTING can - // use to effectively run that same test again. 
- args = []string{"-i", r.tunname, "-j", "MARK", "--set-mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask} - if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) - } - args = []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "ACCEPT"} - if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) - } - args = []string{"-o", r.tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} - if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) - } - args = []string{"-o", r.tunname, "-j", "ACCEPT"} - if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) - } - - return nil -} - -// addNetfilterBase4 adds some basic IPv6 processing rules to be -// supplemented by later calls to other helpers. -func (r *linuxRouter) addNetfilterBase6() error { - // TODO: only allow traffic from Tailscale's ULA range to come - // from tailscale0. - - args := []string{"-i", r.tunname, "-j", "MARK", "--set-mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask} - if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) - } - args = []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "ACCEPT"} - if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) - } - // TODO: drop forwarded traffic to tailscale0 from tailscale's ULA - // (see corresponding IPv4 CGNAT rule). 
- args = []string{"-o", r.tunname, "-j", "ACCEPT"} - if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) - } - - return nil -} - -// delNetfilterChains removes the custom Tailscale chains from netfilter. -func (r *linuxRouter) delNetfilterChains() error { - del := func(ipt netfilterRunner, table, chain string) error { - if err := ipt.ClearChain(table, chain); err != nil { - if errCode(err) == 1 { - // nonexistent chain. That's fine, since it's - // the desired state anyway. - return nil - } - return fmt.Errorf("flushing %s/%s: %w", table, chain, err) - } - if err := ipt.DeleteChain(table, chain); err != nil { - // this shouldn't fail, because if the chain didn't - // exist, we would have returned after ClearChain. - return fmt.Errorf("deleting %s/%s: %v", table, chain, err) - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := del(ipt, "filter", "ts-input"); err != nil { - return err - } - if err := del(ipt, "filter", "ts-forward"); err != nil { - return err - } - } - if err := del(r.ipt4, "nat", "ts-postrouting"); err != nil { - return err - } - if r.v6NATAvailable { - if err := del(r.ipt6, "nat", "ts-postrouting"); err != nil { - return err - } - } - - return nil -} - -// delNetfilterBase empties but does not remove custom Tailscale chains from -// netfilter. -func (r *linuxRouter) delNetfilterBase() error { - del := func(ipt netfilterRunner, table, chain string) error { - if err := ipt.ClearChain(table, chain); err != nil { - if errCode(err) == 1 { - // nonexistent chain. That's fine, since it's - // the desired state anyway. 
- return nil - } - return fmt.Errorf("flushing %s/%s: %w", table, chain, err) - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := del(ipt, "filter", "ts-input"); err != nil { - return err - } - if err := del(ipt, "filter", "ts-forward"); err != nil { - return err - } - } - if err := del(r.ipt4, "nat", "ts-postrouting"); err != nil { - return err - } - if r.v6NATAvailable { - if err := del(r.ipt6, "nat", "ts-postrouting"); err != nil { - return err - } - } - - return nil -} - -// addNetfilterHooks inserts calls to tailscale's netfilter chains in -// the relevant main netfilter chains. The tailscale chains must -// already exist. -func (r *linuxRouter) addNetfilterHooks() error { - divert := func(ipt netfilterRunner, table, chain string) error { - tsChain := tsChain(chain) - - args := []string{"-j", tsChain} - exists, err := ipt.Exists(table, chain, args...) - if err != nil { - return fmt.Errorf("checking for %v in %s/%s: %w", args, table, chain, err) - } - if exists { - return nil - } - if err := ipt.Insert(table, chain, 1, args...); err != nil { - return fmt.Errorf("adding %v in %s/%s: %w", args, table, chain, err) - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := divert(ipt, "filter", "INPUT"); err != nil { - return err - } - if err := divert(ipt, "filter", "FORWARD"); err != nil { - return err - } - } - if err := divert(r.ipt4, "nat", "POSTROUTING"); err != nil { - return err - } - if r.v6NATAvailable { - if err := divert(r.ipt6, "nat", "POSTROUTING"); err != nil { - return err - } - } - return nil -} - -// delNetfilterHooks deletes the calls to tailscale's netfilter chains -// in the relevant main netfilter chains. -func (r *linuxRouter) delNetfilterHooks() error { - del := func(ipt netfilterRunner, table, chain string) error { - tsChain := tsChain(chain) - args := []string{"-j", tsChain} - if err := ipt.Delete(table, chain, args...); err != nil { - // TODO(apenwarr): check for errCode(1) here. 
- // Unfortunately the error code from the iptables - // module resists unwrapping, unlike with other - // calls. So we have to assume if Delete fails, - // it's because there is no such rule. - r.logf("note: deleting %v in %s/%s: %w", args, table, chain, err) - return nil - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := del(ipt, "filter", "INPUT"); err != nil { - return err - } - if err := del(ipt, "filter", "FORWARD"); err != nil { - return err - } - } - if err := del(r.ipt4, "nat", "POSTROUTING"); err != nil { - return err - } - if r.v6NATAvailable { - if err := del(r.ipt6, "nat", "POSTROUTING"); err != nil { - return err - } - } - return nil -} - // addSNATRule adds a netfilter rule to SNAT traffic destined for // local subnets. func (r *linuxRouter) addSNATRule() error { @@ -1524,14 +1301,8 @@ func (r *linuxRouter) addSNATRule() error { return nil } - args := []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "MASQUERADE"} - if err := r.ipt4.Append("nat", "ts-postrouting", args...); err != nil { - return fmt.Errorf("adding %v in v4/nat/ts-postrouting: %w", args, err) - } - if r.v6NATAvailable { - if err := r.ipt6.Append("nat", "ts-postrouting", args...); err != nil { - return fmt.Errorf("adding %v in v6/nat/ts-postrouting: %w", args, err) - } + if err := r.nfr.AddSNATRule(); err != nil { + return err } return nil } @@ -1543,14 +1314,8 @@ func (r *linuxRouter) delSNATRule() error { return nil } - args := []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "MASQUERADE"} - if err := r.ipt4.Delete("nat", "ts-postrouting", args...); err != nil { - return fmt.Errorf("deleting %v in v4/nat/ts-postrouting: %w", args, err) - } - if r.v6NATAvailable { - if err := r.ipt6.Delete("nat", "ts-postrouting", args...); err != nil { - return fmt.Errorf("deleting %v in v6/nat/ts-postrouting: %w", args, err) - } + if err := r.nfr.DelSNATRule(); err != nil { + return 
err } return nil } @@ -1619,118 +1384,20 @@ func cidrDiff(kind string, old map[netip.Prefix]bool, new []netip.Prefix, add, d return ret, nil } -// tsChain returns the name of the tailscale sub-chain corresponding -// to the given "parent" chain (e.g. INPUT, FORWARD, ...). -func tsChain(chain string) string { - return "ts-" + strings.ToLower(chain) -} - // normalizeCIDR returns cidr as an ip/mask string, with the host bits // of the IP address zeroed out. func normalizeCIDR(cidr netip.Prefix) string { return cidr.Masked().String() } +// cleanup removes all the rules and routes that were added by the linux router. +// The function calls cleanup for both iptables and nftables since which ever +// netfilter runner is used, the cleanup function for the other one doesn't do anything. func cleanup(logf logger.Logf, interfaceName string) { - // TODO(dmytro): clean up iptables. -} - -// checkIPv6 checks whether the system appears to have a working IPv6 -// network stack. It returns an error explaining what looks wrong or -// missing. It does not check that IPv6 is currently functional or -// that there's a global address, just that the system would support -// IPv6 if it were on an IPv6 network. -func checkIPv6(logf logger.Logf) error { - _, err := os.Stat("/proc/sys/net/ipv6") - if os.IsNotExist(err) { - return err - } - bs, err := os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_ipv6") - if err != nil { - // Be conservative if we can't find the ipv6 configuration knob. - return err - } - disabled, err := strconv.ParseBool(strings.TrimSpace(string(bs))) - if err != nil { - return errors.New("disable_ipv6 has invalid bool") + if interfaceName != "userspace-networking" { + linuxfw.IPTablesCleanup(logf) + linuxfw.NfTablesCleanUp(logf) } - if disabled { - return errors.New("disable_ipv6 is set") - } - - // Older kernels don't support IPv6 policy routing. Some kernels - // support policy routing but don't have this knob, so absence of - // the knob is not fatal. 
- bs, err = os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_policy") - if err == nil { - disabled, err = strconv.ParseBool(strings.TrimSpace(string(bs))) - if err != nil { - return errors.New("disable_policy has invalid bool") - } - if disabled { - return errors.New("disable_policy is set") - } - } - - if err := checkIPRuleSupportsV6(logf); err != nil { - return fmt.Errorf("kernel doesn't support IPv6 policy routing: %w", err) - } - - // Some distros ship ip6tables separately from iptables. - if _, err := exec.LookPath("ip6tables"); err != nil { - return err - } - - return nil -} - -// supportsV6NAT returns whether the system has a "nat" table in the -// IPv6 netfilter stack. -// -// The nat table was added after the initial release of ipv6 -// netfilter, so some older distros ship a kernel that can't NAT IPv6 -// traffic. -func supportsV6NAT() bool { - bs, err := os.ReadFile("/proc/net/ip6_tables_names") - if err != nil { - // Can't read the file. Assume SNAT works. - return true - } - if bytes.Contains(bs, []byte("nat\n")) { - return true - } - // In nftables mode, that proc file will be empty. Try another thing: - if exec.Command("modprobe", "ip6table_nat").Run() == nil { - return true - } - return false -} - -func checkIPRuleSupportsV6(logf logger.Logf) error { - // First try just a read-only operation to ideally avoid - // having to modify any state. - if rules, err := netlink.RuleList(netlink.FAMILY_V6); err != nil { - return fmt.Errorf("querying IPv6 policy routing rules: %w", err) - } else { - if len(rules) > 0 { - logf("[v1] kernel supports IPv6 policy routing (found %d rules)", len(rules)) - return nil - } - } - - // Try to actually create & delete one as a test. - rule := netlink.NewRule() - rule.Priority = 1234 - rule.Mark = tailscaleBypassMarkNum - rule.Table = tailscaleRouteTable.num - rule.Family = netlink.FAMILY_V6 - // First delete the rule unconditionally, and don't check for - // errors. 
This is just cleaning up anything that might be already - // there. - netlink.RuleDel(rule) - // And clean up on exit. - defer netlink.RuleDel(rule) - return netlink.RuleAdd(rule) } // Checks if the running openWRT system is using mwan3, based on the heuristic diff --git a/vendor/tailscale.com/wgengine/router/router_openbsd.go b/vendor/tailscale.com/wgengine/router/router_openbsd.go index c23d37e474..b859927799 100644 --- a/vendor/tailscale.com/wgengine/router/router_openbsd.go +++ b/vendor/tailscale.com/wgengine/router/router_openbsd.go @@ -14,6 +14,7 @@ import ( "go4.org/netipx" "tailscale.com/net/netmon" "tailscale.com/types/logger" + "tailscale.com/util/set" ) // For now this router only supports the WireGuard userspace implementation. @@ -26,7 +27,7 @@ type openbsdRouter struct { tunname string local4 netip.Prefix local6 netip.Prefix - routes map[netip.Prefix]struct{} + routes set.Set[netip.Prefix] } func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor) (Router, error) { @@ -173,9 +174,9 @@ func (r *openbsdRouter) Set(cfg *Config) error { } } - newRoutes := make(map[netip.Prefix]struct{}) + newRoutes := set.Set[netip.Prefix]{} for _, route := range cfg.Routes { - newRoutes[route] = struct{}{} + newRoutes.Add(route) } for route := range r.routes { if _, keep := newRoutes[route]; !keep { diff --git a/vendor/tailscale.com/wgengine/userspace.go b/vendor/tailscale.com/wgengine/userspace.go index e29e20daec..397348ee25 100644 --- a/vendor/tailscale.com/wgengine/userspace.go +++ b/vendor/tailscale.com/wgengine/userspace.go @@ -19,14 +19,12 @@ import ( "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" - "golang.org/x/exp/maps" - "tailscale.com/control/controlclient" + "tailscale.com/control/controlknobs" "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" "tailscale.com/net/flowtrack" - "tailscale.com/net/interfaces" "tailscale.com/net/netmon" 
"tailscale.com/net/packet" "tailscale.com/net/sockstats" @@ -42,9 +40,11 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/deephash" "tailscale.com/util/mak" + "tailscale.com/util/set" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" @@ -95,9 +95,10 @@ type userspaceEngine struct { dns *dns.Manager magicConn *magicsock.Conn netMon *netmon.Monitor - netMonOwned bool // whether we created netMon (and thus need to close it) - netMonUnregister func() // unsubscribes from changes; used regardless of netMonOwned - birdClient BIRDClient // or nil + netMonOwned bool // whether we created netMon (and thus need to close it) + netMonUnregister func() // unsubscribes from changes; used regardless of netMonOwned + birdClient BIRDClient // or nil + controlKnobs *controlknobs.Knobs // or nil testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called @@ -125,15 +126,13 @@ type userspaceEngine struct { statusBufioReader *bufio.Reader // reusable for UAPI lastStatusPollTime mono.Time // last time we polled the engine status - mu sync.Mutex // guards following; see lock order comment below - netMap *netmap.NetworkMap // or nil - closing bool // Close was called (even if we're still closing) - statusCallback StatusCallback - peerSequence []key.NodePublic - endpoints []tailcfg.Endpoint - pendOpen map[flowtrack.Tuple]*pendingOpenFlow // see pendopen.go - networkMapCallbacks map[*someHandle]NetworkMapCallback - tsIPByIPPort map[netip.AddrPort]netip.Addr // allows registration of IP:ports as belonging to a certain Tailscale IP for whois lookups + mu sync.Mutex // guards following; see lock order comment below + netMap *netmap.NetworkMap // or nil + closing bool // Close was called (even if we're still closing) + statusCallback StatusCallback + peerSequence 
[]key.NodePublic + endpoints []tailcfg.Endpoint + pendOpen map[flowtrack.Tuple]*pendingOpenFlow // see pendopen.go // pongCallback is the map of response handlers waiting for disco or TSMP // pong callbacks. The map key is a random slice of bytes. @@ -183,6 +182,11 @@ type Config struct { // If nil, a new Dialer is created Dialer *tsdial.Dialer + // ControlKnobs is the set of control plane-provied knobs + // to use. + // If nil, defaults are used. + ControlKnobs *controlknobs.Knobs + // ListenPort is the port on which the engine will listen. // If zero, a port is automatically selected. ListenPort uint16 @@ -220,6 +224,8 @@ func NewFakeUserspaceEngine(logf logger.Logf, opts ...any) (Engine, error) { conf.ListenPort = uint16(v) case func(any): conf.SetSubsystem = v + case *controlknobs.Knobs: + conf.ControlKnobs = v default: return nil, fmt.Errorf("unknown option type %T", v) } @@ -271,6 +277,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) router: conf.Router, confListenPort: conf.ListenPort, birdClient: conf.BIRDClient, + controlKnobs: conf.ControlKnobs, } if e.birdClient != nil { @@ -279,8 +286,8 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) return nil, err } } - e.isLocalAddr.Store(tsaddr.NewContainsIPFunc(nil)) - e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(nil)) + e.isLocalAddr.Store(tsaddr.FalseContainsIPFunc()) + e.isDNSIPOverTailscale.Store(tsaddr.FalseContainsIPFunc()) if conf.NetMon != nil { e.netMon = conf.NetMon @@ -304,9 +311,9 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) logf("link state: %+v", e.netMon.InterfaceState()) - unregisterMonWatch := e.netMon.RegisterChangeCallback(func(changed bool, st *interfaces.State) { + unregisterMonWatch := e.netMon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { tshttpproxy.InvalidateCache() - e.linkChange(changed, st) + e.linkChange(delta) }) closePool.addFunc(unregisterMonWatch) 
e.netMonUnregister = unregisterMonWatch @@ -326,6 +333,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) IdleFunc: e.tundev.IdleDuration, NoteRecvActivity: e.noteRecvActivity, NetMon: e.netMon, + ControlKnobs: conf.ControlKnobs, } var err error @@ -493,15 +501,12 @@ var debugTrimWireguard = envknob.RegisterOptBool("TS_DEBUG_TRIM_WIREGUARD") // That's sad too. Or we get rid of these knobs (lazy wireguard config has been // stable!) but I'm worried that a future regression would be easier to debug // with these knobs in place. -func forceFullWireguardConfig(numPeers int) bool { +func (e *userspaceEngine) forceFullWireguardConfig(numPeers int) bool { // Did the user explicitly enable trimming via the environment variable knob? if b, ok := debugTrimWireguard().Get(); ok { return !b } - if opt := controlclient.TrimWGConfig(); opt != "" { - return !opt.EqualBool(true) - } - return false + return e.controlKnobs != nil && e.controlKnobs.KeepFullWGConfig.Load() } // isTrimmablePeer reports whether p is a peer that we can trim out of the @@ -511,8 +516,8 @@ func forceFullWireguardConfig(numPeers int) bool { // only non-subnet AllowedIPs (an IPv4 /32 or IPv6 /128), which is the // common case for most peers. Subnet router nodes will just always be // created in the wireguard-go config. -func isTrimmablePeer(p *wgcfg.Peer, numPeers int) bool { - if forceFullWireguardConfig(numPeers) { +func (e *userspaceEngine) isTrimmablePeer(p *wgcfg.Peer, numPeers int) bool { + if e.forceFullWireguardConfig(numPeers) { return false } @@ -619,7 +624,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node // Don't re-alloc the map; the Go compiler optimizes map clears as of // Go 1.11, so we can re-use the existing + allocated map. 
if e.trimmedNodes != nil { - maps.Clear(e.trimmedNodes) + clear(e.trimmedNodes) } else { e.trimmedNodes = make(map[key.NodePublic]bool) } @@ -628,7 +633,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node for i := range full.Peers { p := &full.Peers[i] nk := p.PublicKey - if !isTrimmablePeer(p, len(full.Peers)) { + if !e.isTrimmablePeer(p, len(full.Peers)) { min.Peers = append(min.Peers, *p) if discoChanged[nk] { needRemoveStep = true @@ -759,18 +764,17 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, // hasOverlap checks if there is a IPPrefix which is common amongst the two // provided slices. -func hasOverlap(aips, rips []netip.Prefix) bool { - for _, aip := range aips { - for _, rip := range rips { - if aip == rip { - return true - } +func hasOverlap(aips, rips views.Slice[netip.Prefix]) bool { + for i := range aips.LenIter() { + aip := aips.At(i) + if views.SliceContains(rips, aip) { + return true } } return false } -func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config, debug *tailcfg.Debug) error { +func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") } @@ -778,33 +782,35 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, panic("dnsCfg must not be nil") } - e.isLocalAddr.Store(tsaddr.NewContainsIPFunc(routerCfg.LocalAddrs)) + e.isLocalAddr.Store(tsaddr.NewContainsIPFunc(views.SliceOf(routerCfg.LocalAddrs))) e.wgLock.Lock() defer e.wgLock.Unlock() e.tundev.SetWGConfig(cfg) e.lastDNSConfig = dnsCfg - peerSet := make(map[key.NodePublic]struct{}, len(cfg.Peers)) + peerSet := make(set.Set[key.NodePublic], len(cfg.Peers)) e.mu.Lock() e.peerSequence = e.peerSequence[:0] for _, p := range cfg.Peers { e.peerSequence = append(e.peerSequence, p.PublicKey) - peerSet[p.PublicKey] = struct{}{} + peerSet.Add(p.PublicKey) 
} nm := e.netMap e.mu.Unlock() listenPort := e.confListenPort - if debug != nil && debug.RandomizeClientPort { + if e.controlKnobs != nil && e.controlKnobs.RandomizeClientPort.Load() { listenPort = 0 } + peerMTUEnable := e.magicConn.ShouldPMTUD() + isSubnetRouter := false - if e.birdClient != nil && nm != nil && nm.SelfNode != nil { - isSubnetRouter = hasOverlap(nm.SelfNode.PrimaryRoutes, nm.Hostinfo.RoutableIPs) + if e.birdClient != nil && nm != nil && nm.SelfNode.Valid() { + isSubnetRouter = hasOverlap(nm.SelfNode.PrimaryRoutes(), nm.SelfNode.Hostinfo().RoutableIPs()) e.logf("[v1] Reconfig: hasOverlap(%v, %v) = %v; isSubnetRouter=%v lastIsSubnetRouter=%v", - nm.SelfNode.PrimaryRoutes, nm.Hostinfo.RoutableIPs, + nm.SelfNode.PrimaryRoutes(), nm.SelfNode.Hostinfo().RoutableIPs(), isSubnetRouter, isSubnetRouter, e.lastIsSubnetRouter) } isSubnetRouterChanged := isSubnetRouter != e.lastIsSubnetRouter @@ -814,7 +820,9 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, RouterConfig *router.Config DNSConfig *dns.Config }{routerCfg, dnsCfg}) - if !engineChanged && !routerChanged && listenPort == e.magicConn.LocalPort() && !isSubnetRouterChanged { + listenPortChanged := listenPort != e.magicConn.LocalPort() + peerMTUChanged := peerMTUEnable != e.magicConn.PeerMTUEnabled() + if !engineChanged && !routerChanged && !listenPortChanged && !isSubnetRouterChanged && !peerMTUChanged { return ErrNoChanges } newLogIDs := cfg.NetworkLogging @@ -832,7 +840,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, // instead have ipnlocal populate a map of DNS IP => linkName and // put that in the *dns.Config instead, and plumb it down to the // dns.Manager. Maybe also with isLocalAddr above. 
- e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(dnsIPsOverTailscale(dnsCfg, routerCfg))) + e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(views.SliceOf(dnsIPsOverTailscale(dnsCfg, routerCfg)))) // See if any peers have changed disco keys, which means they've restarted. // If so, we need to update the wireguard-go/device.Device in two phases: @@ -870,6 +878,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } e.magicConn.UpdatePeers(peerSet) e.magicConn.SetPreferredPort(listenPort) + e.magicConn.UpdatePMTUD() if err := e.maybeReconfigWireguardLocked(discoChanged); err != nil { return err @@ -1094,14 +1103,9 @@ func (e *userspaceEngine) Wait() { <-e.waitCh } -// LinkChange signals a network change event. It's currently -// (2021-03-03) only called on Android. On other platforms, netMon -// generates link change events for us. -func (e *userspaceEngine) LinkChange(_ bool) { - e.netMon.InjectEvent() -} - -func (e *userspaceEngine) linkChange(changed bool, cur *interfaces.State) { +func (e *userspaceEngine) linkChange(delta *netmon.ChangeDelta) { + changed := delta.Major // TODO(bradfitz): ask more specific questions? 
+ cur := delta.New up := cur.AnyInterfaceUp() if !up { e.logf("LinkChange: all links down; pausing: %v", cur) @@ -1150,45 +1154,11 @@ func (e *userspaceEngine) linkChange(changed bool, cur *interfaces.State) { e.magicConn.ReSTUN(why) } -func (e *userspaceEngine) AddNetworkMapCallback(cb NetworkMapCallback) func() { - e.mu.Lock() - defer e.mu.Unlock() - if e.networkMapCallbacks == nil { - e.networkMapCallbacks = make(map[*someHandle]NetworkMapCallback) - } - h := new(someHandle) - e.networkMapCallbacks[h] = cb - return func() { - e.mu.Lock() - defer e.mu.Unlock() - delete(e.networkMapCallbacks, h) - } -} - -func (e *userspaceEngine) SetNetInfoCallback(cb NetInfoCallback) { - e.magicConn.SetNetInfoCallback(cb) -} - -func (e *userspaceEngine) SetDERPMap(dm *tailcfg.DERPMap) { - e.magicConn.SetDERPMap(dm) -} - func (e *userspaceEngine) SetNetworkMap(nm *netmap.NetworkMap) { e.magicConn.SetNetworkMap(nm) e.mu.Lock() e.netMap = nm - callbacks := make([]NetworkMapCallback, 0, 4) - for _, fn := range e.networkMapCallbacks { - callbacks = append(callbacks, fn) - } e.mu.Unlock() - for _, fn := range callbacks { - fn(nm) - } -} - -func (e *userspaceEngine) DiscoPublicKey() key.DiscoPublic { - return e.magicConn.DiscoPublicKey() } func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { @@ -1211,7 +1181,7 @@ func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { e.magicConn.UpdateStatus(sb) } -func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, cb func(*ipnstate.PingResult)) { +func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) { res := &ipnstate.PingResult{IP: ip.String()} pip, ok := e.PeerForIP(ip) if !ok { @@ -1228,10 +1198,10 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, cb func } peer := pip.Node - e.logf("ping(%v): sending %v ping to %v %v ...", ip, pingType, peer.Key.ShortString(), peer.ComputedName) + e.logf("ping(%v): sending %v 
ping to %v %v ...", ip, pingType, peer.Key().ShortString(), peer.ComputedName()) switch pingType { case "disco": - e.magicConn.Ping(peer, res, cb) + e.magicConn.Ping(peer, res, size, cb) case "TSMP": e.sendTSMPPing(ip, peer, res, cb) case "ICMP": @@ -1240,23 +1210,25 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, cb func } func (e *userspaceEngine) mySelfIPMatchingFamily(dst netip.Addr) (src netip.Addr, err error) { + var zero netip.Addr e.mu.Lock() defer e.mu.Unlock() if e.netMap == nil { - return netip.Addr{}, errors.New("no netmap") + return zero, errors.New("no netmap") + } + addrs := e.netMap.GetAddresses() + if addrs.Len() == 0 { + return zero, errors.New("no self address in netmap") } - for _, a := range e.netMap.Addresses { - if a.IsSingleIP() && a.Addr().BitLen() == dst.BitLen() { + for i := range addrs.LenIter() { + if a := addrs.At(i); a.IsSingleIP() && a.Addr().BitLen() == dst.BitLen() { return a.Addr(), nil } } - if len(e.netMap.Addresses) == 0 { - return netip.Addr{}, errors.New("no self address in netmap") - } - return netip.Addr{}, errors.New("no self address in netmap matching address family") + return zero, errors.New("no self address in netmap matching address family") } -func (e *userspaceEngine) sendICMPEchoRequest(destIP netip.Addr, peer *tailcfg.Node, res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { +func (e *userspaceEngine) sendICMPEchoRequest(destIP netip.Addr, peer tailcfg.NodeView, res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { srcIP, err := e.mySelfIPMatchingFamily(destIP) if err != nil { res.Err = err.Error() @@ -1297,7 +1269,7 @@ func (e *userspaceEngine) sendICMPEchoRequest(destIP netip.Addr, peer *tailcfg.N d := time.Since(t0) res.LatencySeconds = d.Seconds() res.NodeIP = destIP.String() - res.NodeName = peer.ComputedName + res.NodeName = peer.ComputedName() cb(res) }) @@ -1305,7 +1277,7 @@ func (e *userspaceEngine) sendICMPEchoRequest(destIP netip.Addr, peer *tailcfg.N 
e.tundev.InjectOutbound(icmpPing) } -func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer *tailcfg.Node, res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { +func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer tailcfg.NodeView, res *ipnstate.PingResult, cb func(*ipnstate.PingResult)) { srcIP, err := e.mySelfIPMatchingFamily(ip) if err != nil { res.Err = err.Error() @@ -1339,7 +1311,7 @@ func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer *tailcfg.Node, res *i d := time.Since(t0) res.LatencySeconds = d.Seconds() res.NodeIP = ip.String() - res.NodeName = peer.ComputedName + res.NodeName = peer.ComputedName() res.PeerAPIPort = pong.PeerAPIPort cb(res) }) @@ -1375,50 +1347,6 @@ func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) { } } -func (e *userspaceEngine) RegisterIPPortIdentity(ipport netip.AddrPort, tsIP netip.Addr) { - e.mu.Lock() - defer e.mu.Unlock() - if e.tsIPByIPPort == nil { - e.tsIPByIPPort = make(map[netip.AddrPort]netip.Addr) - } - e.tsIPByIPPort[ipport] = tsIP -} - -func (e *userspaceEngine) UnregisterIPPortIdentity(ipport netip.AddrPort) { - e.mu.Lock() - defer e.mu.Unlock() - if e.tsIPByIPPort == nil { - return - } - delete(e.tsIPByIPPort, ipport) -} - -var whoIsSleeps = [...]time.Duration{ - 0, - 10 * time.Millisecond, - 20 * time.Millisecond, - 50 * time.Millisecond, - 100 * time.Millisecond, -} - -func (e *userspaceEngine) WhoIsIPPort(ipport netip.AddrPort) (tsIP netip.Addr, ok bool) { - // We currently have a registration race, - // https://github.com/tailscale/tailscale/issues/1616, - // so loop a few times for now waiting for the registration - // to appear. - // TODO(bradfitz,namansood): remove this once #1616 is fixed. 
- for _, d := range whoIsSleeps { - time.Sleep(d) - e.mu.Lock() - tsIP, ok = e.tsIPByIPPort[ipport] - e.mu.Unlock() - if ok { - return tsIP, true - } - } - return tsIP, false -} - // PeerForIP returns the Node in the wireguard config // that's responsible for handling the given IP address. // @@ -1438,14 +1366,16 @@ func (e *userspaceEngine) PeerForIP(ip netip.Addr) (ret PeerForIP, ok bool) { // Check for exact matches before looking for subnet matches. // TODO(bradfitz): add maps for these. on NetworkMap? for _, p := range nm.Peers { - for _, a := range p.Addresses { + for i := range p.Addresses().LenIter() { + a := p.Addresses().At(i) if a.Addr() == ip && a.IsSingleIP() && tsaddr.IsTailscaleIP(ip) { return PeerForIP{Node: p, Route: a}, true } } } - for _, a := range nm.Addresses { - if a.Addr() == ip && a.IsSingleIP() && tsaddr.IsTailscaleIP(ip) { + addrs := nm.GetAddresses() + for i := range addrs.LenIter() { + if a := addrs.At(i); a.Addr() == ip && a.IsSingleIP() && tsaddr.IsTailscaleIP(ip) { return PeerForIP{Node: nm.SelfNode, IsSelf: true, Route: a}, true } } @@ -1471,7 +1401,7 @@ func (e *userspaceEngine) PeerForIP(ip netip.Addr) (ret PeerForIP, ok bool) { // call. But TODO(bradfitz): add a lookup map to netmap.NetworkMap. 
if !bestKey.IsZero() { for _, p := range nm.Peers { - if p.Key == bestKey { + if p.Key() == bestKey { return PeerForIP{Node: p, Route: best}, true } } diff --git a/vendor/tailscale.com/wgengine/watchdog.go b/vendor/tailscale.com/wgengine/watchdog.go index 19505be896..75bc1c0ead 100644 --- a/vendor/tailscale.com/wgengine/watchdog.go +++ b/vendor/tailscale.com/wgengine/watchdog.go @@ -18,7 +18,6 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" "tailscale.com/tailcfg" - "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" @@ -119,8 +118,8 @@ func (e *watchdogEngine) watchdog(name string, fn func()) { }) } -func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config, debug *tailcfg.Debug) error { - return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, routerCfg, dnsCfg, debug) }) +func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { + return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, routerCfg, dnsCfg) }) } func (e *watchdogEngine) GetFilter() *filter.Filter { return e.wrap.GetFilter() @@ -134,42 +133,14 @@ func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) { func (e *watchdogEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { e.watchdog("UpdateStatus", func() { e.wrap.UpdateStatus(sb) }) } -func (e *watchdogEngine) SetNetInfoCallback(cb NetInfoCallback) { - e.watchdog("SetNetInfoCallback", func() { e.wrap.SetNetInfoCallback(cb) }) -} func (e *watchdogEngine) RequestStatus() { e.watchdog("RequestStatus", func() { e.wrap.RequestStatus() }) } -func (e *watchdogEngine) LinkChange(isExpensive bool) { - e.watchdog("LinkChange", func() { e.wrap.LinkChange(isExpensive) }) -} -func (e *watchdogEngine) SetDERPMap(m *tailcfg.DERPMap) { - e.watchdog("SetDERPMap", func() { e.wrap.SetDERPMap(m) }) -} func (e *watchdogEngine) SetNetworkMap(nm 
*netmap.NetworkMap) { e.watchdog("SetNetworkMap", func() { e.wrap.SetNetworkMap(nm) }) } -func (e *watchdogEngine) AddNetworkMapCallback(callback NetworkMapCallback) func() { - var fn func() - e.watchdog("AddNetworkMapCallback", func() { fn = e.wrap.AddNetworkMapCallback(callback) }) - return func() { e.watchdog("RemoveNetworkMapCallback", fn) } -} -func (e *watchdogEngine) DiscoPublicKey() (k key.DiscoPublic) { - e.watchdog("DiscoPublicKey", func() { k = e.wrap.DiscoPublicKey() }) - return k -} -func (e *watchdogEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, cb func(*ipnstate.PingResult)) { - e.watchdog("Ping", func() { e.wrap.Ping(ip, pingType, cb) }) -} -func (e *watchdogEngine) RegisterIPPortIdentity(ipp netip.AddrPort, tsIP netip.Addr) { - e.watchdog("RegisterIPPortIdentity", func() { e.wrap.RegisterIPPortIdentity(ipp, tsIP) }) -} -func (e *watchdogEngine) UnregisterIPPortIdentity(ipp netip.AddrPort) { - e.watchdog("UnregisterIPPortIdentity", func() { e.wrap.UnregisterIPPortIdentity(ipp) }) -} -func (e *watchdogEngine) WhoIsIPPort(ipp netip.AddrPort) (tsIP netip.Addr, ok bool) { - e.watchdog("UnregisterIPPortIdentity", func() { tsIP, ok = e.wrap.WhoIsIPPort(ipp) }) - return tsIP, ok +func (e *watchdogEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) { + e.watchdog("Ping", func() { e.wrap.Ping(ip, pingType, size, cb) }) } func (e *watchdogEngine) Close() { e.watchdog("Close", e.wrap.Close) diff --git a/vendor/tailscale.com/wgengine/wgcfg/config.go b/vendor/tailscale.com/wgengine/wgcfg/config.go index 18f019b536..a6a130b6fd 100644 --- a/vendor/tailscale.com/wgengine/wgcfg/config.go +++ b/vendor/tailscale.com/wgengine/wgcfg/config.go @@ -38,6 +38,7 @@ type Peer struct { DiscoKey key.DiscoPublic // present only so we can handle restarts within wgengine, not passed to WireGuard AllowedIPs []netip.Prefix V4MasqAddr *netip.Addr // if non-nil, masquerade IPv4 traffic to this peer using this address + V6MasqAddr 
*netip.Addr // if non-nil, masquerade IPv6 traffic to this peer using this address PersistentKeepalive uint16 // wireguard-go's endpoint for this peer. It should always equal Peer.PublicKey. // We represent it explicitly so that we can detect if they diverge and recover. diff --git a/vendor/tailscale.com/wgengine/wgcfg/nmcfg/nmcfg.go b/vendor/tailscale.com/wgengine/wgcfg/nmcfg/nmcfg.go index f01b42cb1e..885d507fab 100644 --- a/vendor/tailscale.com/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/vendor/tailscale.com/wgengine/wgcfg/nmcfg/nmcfg.go @@ -10,7 +10,6 @@ import ( "net/netip" "strings" - "golang.org/x/exp/slices" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/logger" @@ -19,30 +18,31 @@ import ( "tailscale.com/wgengine/wgcfg" ) -func nodeDebugName(n *tailcfg.Node) string { - name := n.Name +func nodeDebugName(n tailcfg.NodeView) string { + name := n.Name() if name == "" { - name = n.Hostinfo.Hostname() + name = n.Hostinfo().Hostname() } if i := strings.Index(name, "."); i != -1 { name = name[:i] } - if name == "" && len(n.Addresses) != 0 { - return n.Addresses[0].String() + if name == "" && n.Addresses().Len() != 0 { + return n.Addresses().At(0).String() } return name } // cidrIsSubnet reports whether cidr is a non-default-route subnet // exported by node that is not one of its own self addresses. 
-func cidrIsSubnet(node *tailcfg.Node, cidr netip.Prefix) bool { +func cidrIsSubnet(node tailcfg.NodeView, cidr netip.Prefix) bool { if cidr.Bits() == 0 { return false } if !cidr.IsSingleIP() { return true } - for _, selfCIDR := range node.Addresses { + for i := range node.Addresses().LenIter() { + selfCIDR := node.Addresses().At(i) if cidr == selfCIDR { return false } @@ -55,16 +55,16 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, cfg := &wgcfg.Config{ Name: "tailscale", PrivateKey: nm.PrivateKey, - Addresses: nm.Addresses, + Addresses: nm.GetAddresses().AsSlice(), Peers: make([]wgcfg.Peer, 0, len(nm.Peers)), } // Setup log IDs for data plane audit logging. - if nm.SelfNode != nil { - cfg.NodeID = nm.SelfNode.StableID - canNetworkLog := slices.Contains(nm.SelfNode.Capabilities, tailcfg.CapabilityDataPlaneAuditLogs) - if canNetworkLog && nm.SelfNode.DataPlaneAuditLogID != "" && nm.DomainAuditLogID != "" { - nodeID, errNode := logid.ParsePrivateID(nm.SelfNode.DataPlaneAuditLogID) + if nm.SelfNode.Valid() { + cfg.NodeID = nm.SelfNode.StableID() + canNetworkLog := nm.SelfNode.HasCap(tailcfg.CapabilityDataPlaneAuditLogs) + if canNetworkLog && nm.SelfNode.DataPlaneAuditLogID() != "" && nm.DomainAuditLogID != "" { + nodeID, errNode := logid.ParsePrivateID(nm.SelfNode.DataPlaneAuditLogID()) if errNode != nil { logf("[v1] wgcfg: unable to parse node audit log ID: %v", errNode) } @@ -85,25 +85,24 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, skippedSubnets := new(bytes.Buffer) for _, peer := range nm.Peers { - if peer.DiscoKey.IsZero() && peer.DERP == "" && !peer.IsWireGuardOnly { + if peer.DiscoKey().IsZero() && peer.DERP() == "" && !peer.IsWireGuardOnly() { // Peer predates both DERP and active discovery, we cannot // communicate with it. 
- logf("[v1] wgcfg: skipped peer %s, doesn't offer DERP or disco", peer.Key.ShortString()) + logf("[v1] wgcfg: skipped peer %s, doesn't offer DERP or disco", peer.Key().ShortString()) continue } cfg.Peers = append(cfg.Peers, wgcfg.Peer{ - PublicKey: peer.Key, - DiscoKey: peer.DiscoKey, + PublicKey: peer.Key(), + DiscoKey: peer.DiscoKey(), }) cpeer := &cfg.Peers[len(cfg.Peers)-1] - if peer.KeepAlive { - cpeer.PersistentKeepalive = 25 // seconds - } didExitNodeWarn := false - cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer - for _, allowedIP := range peer.AllowedIPs { - if allowedIP.Bits() == 0 && peer.StableID != exitNode { + cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer() + cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer() + for i := range peer.AllowedIPs().LenIter() { + allowedIP := peer.AllowedIPs().At(i) + if allowedIP.Bits() == 0 && peer.StableID() != exitNode { if didExitNodeWarn { // Don't log about both the IPv4 /0 and IPv6 /0. continue @@ -112,20 +111,20 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, if skippedUnselected.Len() > 0 { skippedUnselected.WriteString(", ") } - fmt.Fprintf(skippedUnselected, "%q (%v)", nodeDebugName(peer), peer.Key.ShortString()) + fmt.Fprintf(skippedUnselected, "%q (%v)", nodeDebugName(peer), peer.Key().ShortString()) continue } else if allowedIP.IsSingleIP() && tsaddr.IsTailscaleIP(allowedIP.Addr()) && (flags&netmap.AllowSingleHosts) == 0 { if skippedIPs.Len() > 0 { skippedIPs.WriteString(", ") } - fmt.Fprintf(skippedIPs, "%v from %q (%v)", allowedIP.Addr(), nodeDebugName(peer), peer.Key.ShortString()) + fmt.Fprintf(skippedIPs, "%v from %q (%v)", allowedIP.Addr(), nodeDebugName(peer), peer.Key().ShortString()) continue } else if cidrIsSubnet(peer, allowedIP) { if (flags & netmap.AllowSubnetRoutes) == 0 { if skippedSubnets.Len() > 0 { skippedSubnets.WriteString(", ") } - fmt.Fprintf(skippedSubnets, "%v from %q (%v)", allowedIP, nodeDebugName(peer), peer.Key.ShortString()) + 
fmt.Fprintf(skippedSubnets, "%v from %q (%v)", allowedIP, nodeDebugName(peer), peer.Key().ShortString()) continue } } diff --git a/vendor/tailscale.com/wgengine/wgcfg/wgcfg_clone.go b/vendor/tailscale.com/wgengine/wgcfg/wgcfg_clone.go index 6887dd6cc4..4a2288f1ee 100644 --- a/vendor/tailscale.com/wgengine/wgcfg/wgcfg_clone.go +++ b/vendor/tailscale.com/wgengine/wgcfg/wgcfg_clone.go @@ -11,6 +11,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logid" + "tailscale.com/types/ptr" ) // Clone makes a deep copy of Config. @@ -23,9 +24,11 @@ func (src *Config) Clone() *Config { *dst = *src dst.Addresses = append(src.Addresses[:0:0], src.Addresses...) dst.DNS = append(src.DNS[:0:0], src.DNS...) - dst.Peers = make([]Peer, len(src.Peers)) - for i := range dst.Peers { - dst.Peers[i] = *src.Peers[i].Clone() + if src.Peers != nil { + dst.Peers = make([]Peer, len(src.Peers)) + for i := range dst.Peers { + dst.Peers[i] = *src.Peers[i].Clone() + } } return dst } @@ -55,8 +58,10 @@ func (src *Peer) Clone() *Peer { *dst = *src dst.AllowedIPs = append(src.AllowedIPs[:0:0], src.AllowedIPs...) 
if dst.V4MasqAddr != nil { - dst.V4MasqAddr = new(netip.Addr) - *dst.V4MasqAddr = *src.V4MasqAddr + dst.V4MasqAddr = ptr.To(*src.V4MasqAddr) + } + if dst.V6MasqAddr != nil { + dst.V6MasqAddr = ptr.To(*src.V6MasqAddr) } return dst } @@ -67,6 +72,7 @@ var _PeerCloneNeedsRegeneration = Peer(struct { DiscoKey key.DiscoPublic AllowedIPs []netip.Prefix V4MasqAddr *netip.Addr + V6MasqAddr *netip.Addr PersistentKeepalive uint16 WGEndpoint key.NodePublic }{}) diff --git a/vendor/tailscale.com/wgengine/wgengine.go b/vendor/tailscale.com/wgengine/wgengine.go index df591c9e03..e21987f939 100644 --- a/vendor/tailscale.com/wgengine/wgengine.go +++ b/vendor/tailscale.com/wgengine/wgengine.go @@ -11,7 +11,6 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" "tailscale.com/tailcfg" - "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" @@ -35,26 +34,18 @@ type Status struct { // Exactly one of Status or error is non-nil. type StatusCallback func(*Status, error) -// NetInfoCallback is the type used by Engine.SetNetInfoCallback. -type NetInfoCallback func(*tailcfg.NetInfo) - // NetworkMapCallback is the type used by callbacks that hook // into network map updates. type NetworkMapCallback func(*netmap.NetworkMap) -// someHandle is allocated so its pointer address acts as a unique -// map key handle. (It needs to have non-zero size for Go to guarantee -// the pointer is unique.) -type someHandle struct{ _ byte } - // ErrNoChanges is returned by Engine.Reconfig if no changes were made. var ErrNoChanges = errors.New("no changes made to Engine config") // PeerForIP is the type returned by Engine.PeerForIP. type PeerForIP struct { - // Node is the matched node. It's always non-nil when + // Node is the matched node. It's always a valid value when // Engine.PeerForIP returns ok==true. - Node *tailcfg.Node + Node tailcfg.NodeView // IsSelf is whether the Node is the local process. 
IsSelf bool @@ -72,10 +63,8 @@ type Engine interface { // This is called whenever tailcontrol (the control plane) // sends an updated network map. // - // The *tailcfg.Debug parameter can be nil. - // // The returned error is ErrNoChanges if no changes were made. - Reconfig(*wgcfg.Config, *router.Config, *dns.Config, *tailcfg.Debug) error + Reconfig(*wgcfg.Config, *router.Config, *dns.Config) error // PeerForIP returns the node to which the provided IP routes, // if any. If none is found, (nil, false) is returned. @@ -105,26 +94,6 @@ type Engine interface { // TODO: return an error? Wait() - // LinkChange informs the engine that the system network - // link has changed. - // - // The isExpensive parameter is not used. - // - // LinkChange should be called whenever something changed with - // the network, no matter how minor. - // - // Deprecated: don't use this method. It was removed shortly - // before the Tailscale 1.6 release when we remembered that - // Android doesn't use the Linux-based network monitor and has - // its own mechanism that uses LinkChange. Android is the only - // caller of this method now. Don't add more. - LinkChange(isExpensive bool) - - // SetDERPMap controls which (if any) DERP servers are used. - // If nil, DERP is disabled. It starts disabled until a DERP map - // is configured. - SetDERPMap(*tailcfg.DERPMap) - // SetNetworkMap informs the engine of the latest network map // from the server. The network map's DERPMap field should be // ignored as as it might be disabled; get it from SetDERPMap @@ -132,42 +101,15 @@ type Engine interface { // The network map should only be read from. SetNetworkMap(*netmap.NetworkMap) - // AddNetworkMapCallback adds a function to a list of callbacks - // that are called when the network map updates. It returns a - // function that when called would remove the function from the - // list of callbacks. 
- AddNetworkMapCallback(NetworkMapCallback) (removeCallback func()) - - // SetNetInfoCallback sets the function to call when a - // new NetInfo summary is available. - SetNetInfoCallback(NetInfoCallback) - - // DiscoPublicKey gets the public key used for path discovery - // messages. - DiscoPublicKey() key.DiscoPublic - // UpdateStatus populates the network state using the provided // status builder. UpdateStatus(*ipnstate.StatusBuilder) - // Ping is a request to start a ping with the peer handling the given IP and - // then call cb with its ping latency & method. - Ping(ip netip.Addr, pingType tailcfg.PingType, cb func(*ipnstate.PingResult)) - - // RegisterIPPortIdentity registers a given node (identified by its - // Tailscale IP) as temporarily having the given IP:port for whois lookups. - // The IP:port is generally a localhost IP and an ephemeral port, used - // while proxying connections to localhost when tailscaled is running - // in netstack mode. - RegisterIPPortIdentity(netip.AddrPort, netip.Addr) - - // UnregisterIPPortIdentity removes a temporary IP:port registration - // made previously by RegisterIPPortIdentity. - UnregisterIPPortIdentity(netip.AddrPort) - - // WhoIsIPPort looks up an IP:port in the temporary registrations, - // and returns a matching Tailscale IP, if it exists. - WhoIsIPPort(netip.AddrPort) (netip.Addr, bool) + // Ping is a request to start a ping of the given message size to the peer + // handling the given IP, then call cb with its ping latency & method. + // + // If size is zero too small, it is ignored. See tailscale.PingOpts for details. + Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) // InstallCaptureHook registers a function to be called to capture // packets traversing the data path. The hook can be uninstalled by