summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-09-14 16:44:56 -0700
committergVisor bot <gvisor-bot@google.com>2021-09-14 16:47:19 -0700
commit8d14edb14b6b757f049faf760c72d58616903d7a (patch)
tree875ab242cd4593ace898265346b231b45633a3a7
parent603f473ada5f5cacc759c8810df01f47905ba5e9 (diff)
Explicitly bind endpoint to a NIC
Previously, any time a datagram-based network endpoint (e.g. UDP) was bound, the bound NIC is always set based on the bound address (if specified). However, we should only consider the endpoint bound to an NIC if a NIC was explicitly bound to. If an endpoint has been bound to an address and attempts to send packets to an unconnected remote, the endpoint will default to sending packets through the bound address' NIC if not explicitly bound to a NIC. Updates #6565. PiperOrigin-RevId: 396712415
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD1
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go23
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go112
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--test/syscalls/linux/udp_socket.cc61
5 files changed, 183 insertions, 16 deletions
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
index d10e3f13a..d6d3f52a3 100644
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -32,6 +32,7 @@ go_test(
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index c5b575e1c..09b629022 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -44,8 +44,9 @@ type Endpoint struct {
state uint32
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- info stack.TransportEndpointInfo
+ mu sync.RWMutex `state:"nosave"`
+ wasBound bool
+ info stack.TransportEndpointInfo
// owner is the owner of transmitted packets.
owner tcpip.PacketOwner
writeShutdown bool
@@ -248,6 +249,9 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
nicID = e.info.BindNICID
}
+ if nicID == 0 {
+ nicID = e.info.RegisterNICID
+ }
dst, netProto, err := e.checkV4MappedLocked(*opts.To)
if err != nil {
@@ -294,9 +298,9 @@ func (e *Endpoint) Disconnect() {
}
// Exclude ephemerally bound endpoints.
- if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" {
+ if e.wasBound {
e.info.ID = stack.TransportEndpointID{
- LocalAddress: e.info.ID.LocalAddress,
+ LocalAddress: e.info.BindAddr,
}
e.setEndpointState(transport.DatagramEndpointStateBound)
} else {
@@ -477,10 +481,12 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
return err
}
+ e.wasBound = true
+
e.info.ID = stack.TransportEndpointID{
LocalAddress: addr.Addr,
}
- e.info.BindNICID = nicID
+ e.info.BindNICID = addr.NIC
e.info.RegisterNICID = nicID
e.info.BindAddr = addr.Addr
e.effectiveNetProto = netProto
@@ -488,6 +494,13 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
return nil
}
+// WasBound returns true iff the endpoint was ever bound.
+func (e *Endpoint) WasBound() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.wasBound
+}
+
// GetLocalAddress returns the address that the endpoint is bound to.
func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
e.mu.RLock()
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
index 2c43eb66a..d99c961c3 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_test.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -15,6 +15,7 @@
package network_test
import (
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -24,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -33,17 +35,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
-func TestEndpointStateTransitions(t *testing.T) {
- const (
- nicID = 1
- )
+var (
+ ipv4NICAddr = testutil.MustParse4("1.2.3.4")
+ ipv6NICAddr = testutil.MustParse6("a::1")
+ ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
+ ipv6RemoteAddr = testutil.MustParse6("b::1")
+)
- var (
- ipv4NICAddr = testutil.MustParse4("1.2.3.4")
- ipv6NICAddr = testutil.MustParse6("a::1")
- ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
- ipv6RemoteAddr = testutil.MustParse6("b::1")
- )
+func TestEndpointStateTransitions(t *testing.T) {
+ const nicID = 1
data := buffer.View([]byte{1, 2, 4, 5})
v4Checker := func(t *testing.T, b buffer.View) {
@@ -139,6 +139,7 @@ func TestEndpointStateTransitions(t *testing.T) {
var ops tcpip.SocketOptions
var ep network.Endpoint
ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ defer ep.Close()
if state := ep.State(); state != transport.DatagramEndpointStateInitial {
t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
}
@@ -207,3 +208,94 @@ func TestEndpointStateTransitions(t *testing.T) {
})
}
}
+
+func TestBindNICID(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ bindAddr tcpip.Address
+ unicast bool
+ }{
+ {
+ name: "IPv4 multicast",
+ netProto: ipv4.ProtocolNumber,
+ bindAddr: header.IPv4AllSystems,
+ unicast: false,
+ },
+ {
+ name: "IPv6 multicast",
+ netProto: ipv6.ProtocolNumber,
+ bindAddr: header.IPv6AllNodesMulticastAddress,
+ unicast: false,
+ },
+ {
+ name: "IPv4 unicast",
+ netProto: ipv4.ProtocolNumber,
+ bindAddr: ipv4NICAddr,
+ unicast: true,
+ },
+ {
+ name: "IPv6 unicast",
+ netProto: ipv6.ProtocolNumber,
+ bindAddr: ipv6NICAddr,
+ unicast: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, testBindNICID := range []tcpip.NICID{0, nicID} {
+ t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ }
+
+ var ops tcpip.SocketOptions
+ var ep network.Endpoint
+ ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ defer ep.Close()
+ if ep.WasBound() {
+ t.Fatal("got ep.WasBound() = true, want = false")
+ }
+ wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber}
+ if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+ t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff)
+ }
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+ if !ep.WasBound() {
+ t.Error("got ep.WasBound() = false, want = true")
+ }
+ wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr}
+ wantInfo.BindAddr = bindAddr.Addr
+ wantInfo.BindNICID = bindAddr.NIC
+ if test.unicast {
+ wantInfo.RegisterNICID = nicID
+ } else {
+ wantInfo.RegisterNICID = bindAddr.NIC
+ }
+ if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+ t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index f171a16f8..4255457f9 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -547,7 +547,7 @@ func (e *endpoint) Disconnect() tcpip.Error {
info := e.net.Info()
info.ID.LocalPort = e.localPort
info.ID.RemotePort = e.remotePort
- if info.BindNICID != 0 || info.ID.LocalAddress == "" {
+ if e.net.WasBound() {
var err tcpip.Error
id = stack.TransportEndpointID{
LocalPort: info.ID.LocalPort,
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index d58b57c8b..b9af5cbdd 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -602,6 +602,67 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) {
SyscallFailsWithErrno(ENOTCONN));
}
+void ConnectThenDisconnect(const FileDescriptor& sock,
+ const sockaddr* bind_addr,
+ const socklen_t expected_addrlen) {
+ // Connect the bound socket.
+ ASSERT_THAT(connect(sock.get(), bind_addr, expected_addrlen),
+ SyscallSucceeds());
+
+ // Disconnect.
+ {
+ sockaddr_storage unspec = {.ss_family = AF_UNSPEC};
+ ASSERT_THAT(connect(sock.get(), AsSockAddr(&unspec), sizeof(unspec)),
+ SyscallSucceeds());
+ }
+ {
+ // Check that we're not in a bound state.
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ ASSERT_THAT(getsockname(sock.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, expected_addrlen);
+ // Everything should be the zero value except the address family.
+ sockaddr_storage expected = {
+ .ss_family = bind_addr->sa_family,
+ };
+ EXPECT_EQ(memcmp(&expected, &addr, expected_addrlen), 0);
+ }
+
+ {
+ // We are not connected so we have no peer.
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getpeername(sock.get(), AsSockAddr(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+ }
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterBindToUnspecAndConnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ sockaddr_storage unspec = {.ss_family = AF_UNSPEC};
+ int bind_res = bind(sock_.get(), AsSockAddr(&unspec), sizeof(unspec));
+ if (IsRunningOnGvisor() && !IsRunningWithHostinet()) {
+ // TODO(https://gvisor.dev/issue/6575): Match Linux's behaviour.
+ ASSERT_THAT(bind_res, SyscallFailsWithErrno(EINVAL));
+ } else if (GetFamily() == AF_INET) {
+ // Linux allows this for undocumented compatibility reasons:
+ // https://github.com/torvalds/linux/commit/29c486df6a208432b370bd4be99ae1369ede28d8.
+ ASSERT_THAT(bind_res, SyscallSucceeds());
+ } else {
+ ASSERT_THAT(bind_res, SyscallFailsWithErrno(EAFNOSUPPORT));
+ }
+
+ ASSERT_NO_FATAL_FAILURE(ConnectThenDisconnect(sock_, bind_addr_, addrlen_));
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterConnectWithoutBind) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ ASSERT_NO_FATAL_FAILURE(ConnectThenDisconnect(sock_, bind_addr_, addrlen_));
+}
+
TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) {
ASSERT_NO_ERRNO(BindAny());