diff options
-rw-r--r-- | pkg/tcpip/transport/internal/network/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint_test.go | 112 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2 | ||||
-rw-r--r-- | test/syscalls/linux/udp_socket.cc | 61 |
5 files changed, 183 insertions, 16 deletions
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD index d10e3f13a..d6d3f52a3 100644 --- a/pkg/tcpip/transport/internal/network/BUILD +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -32,6 +32,7 @@ go_test( "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index c5b575e1c..09b629022 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -44,8 +44,9 @@ type Endpoint struct { state uint32 // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - info stack.TransportEndpointInfo + mu sync.RWMutex `state:"nosave"` + wasBound bool + info stack.TransportEndpointInfo // owner is the owner of transmitted packets. owner tcpip.PacketOwner writeShutdown bool @@ -248,6 +249,9 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext nicID = e.info.BindNICID } + if nicID == 0 { + nicID = e.info.RegisterNICID + } dst, netProto, err := e.checkV4MappedLocked(*opts.To) if err != nil { @@ -294,9 +298,9 @@ func (e *Endpoint) Disconnect() { } // Exclude ephemerally bound endpoints. - if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" { + if e.wasBound { e.info.ID = stack.TransportEndpointID{ - LocalAddress: e.info.ID.LocalAddress, + LocalAddress: e.info.BindAddr, } e.setEndpointState(transport.DatagramEndpointStateBound) } else { @@ -477,10 +481,12 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto return err } + e.wasBound = true + e.info.ID = stack.TransportEndpointID{ LocalAddress: addr.Addr, } - e.info.BindNICID = nicID + e.info.BindNICID = addr.NIC e.info.RegisterNICID = nicID e.info.BindAddr = addr.Addr e.effectiveNetProto = netProto @@ -488,6 +494,13 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto return nil } +// WasBound returns true iff the endpoint was ever bound. +func (e *Endpoint) WasBound() bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.wasBound +} + // GetLocalAddress returns the address that the endpoint is bound to. func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { e.mu.RLock() diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go index 2c43eb66a..d99c961c3 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_test.go +++ b/pkg/tcpip/transport/internal/network/endpoint_test.go @@ -15,6 +15,7 @@ package network_test import ( + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -24,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -33,17 +35,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -func TestEndpointStateTransitions(t *testing.T) { - const ( - nicID = 1 - ) +var ( + ipv4NICAddr = testutil.MustParse4("1.2.3.4") + ipv6NICAddr = testutil.MustParse6("a::1") + ipv4RemoteAddr = testutil.MustParse4("6.7.8.9") + ipv6RemoteAddr = testutil.MustParse6("b::1") +) - var ( - ipv4NICAddr = testutil.MustParse4("1.2.3.4") - ipv6NICAddr = testutil.MustParse6("a::1") - ipv4RemoteAddr = testutil.MustParse4("6.7.8.9") - ipv6RemoteAddr = testutil.MustParse6("b::1") - ) +func TestEndpointStateTransitions(t *testing.T) { + const nicID = 1 data := buffer.View([]byte{1, 2, 4, 5}) v4Checker := func(t *testing.T, b buffer.View) { @@ -139,6 +139,7 @@ func TestEndpointStateTransitions(t *testing.T) { var ops tcpip.SocketOptions var ep network.Endpoint ep.Init(s, test.netProto, udp.ProtocolNumber, &ops) + defer ep.Close() if state := ep.State(); state != transport.DatagramEndpointStateInitial { t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial) } @@ -207,3 +208,94 @@ func TestEndpointStateTransitions(t *testing.T) { }) } } + +func TestBindNICID(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + bindAddr tcpip.Address + unicast bool + }{ + { + name: "IPv4 multicast", + netProto: ipv4.ProtocolNumber, + bindAddr: header.IPv4AllSystems, + unicast: false, + }, + { + name: "IPv6 multicast", + netProto: ipv6.ProtocolNumber, + bindAddr: header.IPv6AllNodesMulticastAddress, + unicast: false, + }, + { + name: "IPv4 unicast", + netProto: ipv4.ProtocolNumber, + bindAddr: ipv4NICAddr, + unicast: true, + }, + { + name: "IPv6 unicast", + netProto: ipv6.ProtocolNumber, + bindAddr: ipv6NICAddr, + unicast: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, testBindNICID := range []tcpip.NICID{0, nicID} { + t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err) + } + + var ops tcpip.SocketOptions + var ep network.Endpoint + ep.Init(s, test.netProto, udp.ProtocolNumber, &ops) + defer ep.Close() + if ep.WasBound() { + t.Fatal("got ep.WasBound() = true, want = false") + } + wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber} + if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" { + t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff) + } + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + if !ep.WasBound() { + t.Error("got ep.WasBound() = false, want = true") + } + wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr} + wantInfo.BindAddr = bindAddr.Addr + wantInfo.BindNICID = bindAddr.NIC + if test.unicast { + wantInfo.RegisterNICID = nicID + } else { + wantInfo.RegisterNICID = bindAddr.NIC + } + if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" { + t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f171a16f8..4255457f9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -547,7 +547,7 @@ func (e *endpoint) Disconnect() tcpip.Error { info := e.net.Info() info.ID.LocalPort = e.localPort info.ID.RemotePort = e.remotePort - if info.BindNICID != 0 || info.ID.LocalAddress == "" { + if e.net.WasBound() { var err tcpip.Error id = stack.TransportEndpointID{ LocalPort: info.ID.LocalPort, diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index d58b57c8b..b9af5cbdd 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -602,6 +602,67 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) { SyscallFailsWithErrno(ENOTCONN)); } +void ConnectThenDisconnect(const FileDescriptor& sock, + const sockaddr* bind_addr, + const socklen_t expected_addrlen) { + // Connect the bound socket. + ASSERT_THAT(connect(sock.get(), bind_addr, expected_addrlen), + SyscallSucceeds()); + + // Disconnect. + { + sockaddr_storage unspec = {.ss_family = AF_UNSPEC}; + ASSERT_THAT(connect(sock.get(), AsSockAddr(&unspec), sizeof(unspec)), + SyscallSucceeds()); + } + { + // Check that we're not in a bound state. + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + ASSERT_THAT(getsockname(sock.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_EQ(addrlen, expected_addrlen); + // Everything should be the zero value except the address family. + sockaddr_storage expected = { + .ss_family = bind_addr->sa_family, + }; + EXPECT_EQ(memcmp(&expected, &addr, expected_addrlen), 0); + } + + { + // We are not connected so we have no peer. + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getpeername(sock.get(), AsSockAddr(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterBindToUnspecAndConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + sockaddr_storage unspec = {.ss_family = AF_UNSPEC}; + int bind_res = bind(sock_.get(), AsSockAddr(&unspec), sizeof(unspec)); + if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { + // TODO(https://gvisor.dev/issue/6575): Match Linux's behaviour. + ASSERT_THAT(bind_res, SyscallFailsWithErrno(EINVAL)); + } else if (GetFamily() == AF_INET) { + // Linux allows this for undocumented compatibility reasons: + // https://github.com/torvalds/linux/commit/29c486df6a208432b370bd4be99ae1369ede28d8. + ASSERT_THAT(bind_res, SyscallSucceeds()); + } else { + ASSERT_THAT(bind_res, SyscallFailsWithErrno(EAFNOSUPPORT)); + } + + ASSERT_NO_FATAL_FAILURE(ConnectThenDisconnect(sock_, bind_addr_, addrlen_)); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectWithoutBind) { + ASSERT_NO_ERRNO(BindLoopback()); + + ASSERT_NO_FATAL_FAILURE(ConnectThenDisconnect(sock_, bind_addr_, addrlen_)); +} + TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { ASSERT_NO_ERRNO(BindAny()); |