summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD1
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go23
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go112
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--test/syscalls/linux/udp_socket.cc61
5 files changed, 183 insertions, 16 deletions
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
index d10e3f13a..d6d3f52a3 100644
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -32,6 +32,7 @@ go_test(
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index c5b575e1c..09b629022 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -44,8 +44,9 @@ type Endpoint struct {
state uint32
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- info stack.TransportEndpointInfo
+ mu sync.RWMutex `state:"nosave"`
+ wasBound bool
+ info stack.TransportEndpointInfo
// owner is the owner of transmitted packets.
owner tcpip.PacketOwner
writeShutdown bool
@@ -248,6 +249,9 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
nicID = e.info.BindNICID
}
+ if nicID == 0 {
+ nicID = e.info.RegisterNICID
+ }
dst, netProto, err := e.checkV4MappedLocked(*opts.To)
if err != nil {
@@ -294,9 +298,9 @@ func (e *Endpoint) Disconnect() {
}
// Exclude ephemerally bound endpoints.
- if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" {
+ if e.wasBound {
e.info.ID = stack.TransportEndpointID{
- LocalAddress: e.info.ID.LocalAddress,
+ LocalAddress: e.info.BindAddr,
}
e.setEndpointState(transport.DatagramEndpointStateBound)
} else {
@@ -477,10 +481,12 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
return err
}
+ e.wasBound = true
+
e.info.ID = stack.TransportEndpointID{
LocalAddress: addr.Addr,
}
- e.info.BindNICID = nicID
+ e.info.BindNICID = addr.NIC
e.info.RegisterNICID = nicID
e.info.BindAddr = addr.Addr
e.effectiveNetProto = netProto
@@ -488,6 +494,13 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
return nil
}
+// WasBound returns true iff the endpoint was ever bound.
+func (e *Endpoint) WasBound() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.wasBound
+}
+
// GetLocalAddress returns the address that the endpoint is bound to.
func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
e.mu.RLock()
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
index 2c43eb66a..d99c961c3 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_test.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -15,6 +15,7 @@
package network_test
import (
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -24,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -33,17 +35,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
-func TestEndpointStateTransitions(t *testing.T) {
- const (
- nicID = 1
- )
+var (
+ ipv4NICAddr = testutil.MustParse4("1.2.3.4")
+ ipv6NICAddr = testutil.MustParse6("a::1")
+ ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
+ ipv6RemoteAddr = testutil.MustParse6("b::1")
+)
- var (
- ipv4NICAddr = testutil.MustParse4("1.2.3.4")
- ipv6NICAddr = testutil.MustParse6("a::1")
- ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
- ipv6RemoteAddr = testutil.MustParse6("b::1")
- )
+func TestEndpointStateTransitions(t *testing.T) {
+ const nicID = 1
data := buffer.View([]byte{1, 2, 4, 5})
v4Checker := func(t *testing.T, b buffer.View) {
@@ -139,6 +139,7 @@ func TestEndpointStateTransitions(t *testing.T) {
var ops tcpip.SocketOptions
var ep network.Endpoint
ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ defer ep.Close()
if state := ep.State(); state != transport.DatagramEndpointStateInitial {
t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
}
@@ -207,3 +208,94 @@ func TestEndpointStateTransitions(t *testing.T) {
})
}
}
+
+func TestBindNICID(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ bindAddr tcpip.Address
+ unicast bool
+ }{
+ {
+ name: "IPv4 multicast",
+ netProto: ipv4.ProtocolNumber,
+ bindAddr: header.IPv4AllSystems,
+ unicast: false,
+ },
+ {
+ name: "IPv6 multicast",
+ netProto: ipv6.ProtocolNumber,
+ bindAddr: header.IPv6AllNodesMulticastAddress,
+ unicast: false,
+ },
+ {
+ name: "IPv4 unicast",
+ netProto: ipv4.ProtocolNumber,
+ bindAddr: ipv4NICAddr,
+ unicast: true,
+ },
+ {
+ name: "IPv6 unicast",
+ netProto: ipv6.ProtocolNumber,
+ bindAddr: ipv6NICAddr,
+ unicast: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, testBindNICID := range []tcpip.NICID{0, nicID} {
+ t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ }
+
+ var ops tcpip.SocketOptions
+ var ep network.Endpoint
+ ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ defer ep.Close()
+ if ep.WasBound() {
+ t.Fatal("got ep.WasBound() = true, want = false")
+ }
+ wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber}
+ if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+ t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff)
+ }
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+ if !ep.WasBound() {
+ t.Error("got ep.WasBound() = false, want = true")
+ }
+ wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr}
+ wantInfo.BindAddr = bindAddr.Addr
+ wantInfo.BindNICID = bindAddr.NIC
+ if test.unicast {
+ wantInfo.RegisterNICID = nicID
+ } else {
+ wantInfo.RegisterNICID = bindAddr.NIC
+ }
+ if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+ t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index f171a16f8..4255457f9 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -547,7 +547,7 @@ func (e *endpoint) Disconnect() tcpip.Error {
info := e.net.Info()
info.ID.LocalPort = e.localPort
info.ID.RemotePort = e.remotePort
- if info.BindNICID != 0 || info.ID.LocalAddress == "" {
+ if e.net.WasBound() {
var err tcpip.Error
id = stack.TransportEndpointID{
LocalPort: info.ID.LocalPort,
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index d58b57c8b..b9af5cbdd 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -602,6 +602,67 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) {
SyscallFailsWithErrno(ENOTCONN));
}
+void ConnectThenDisconnect(const FileDescriptor& sock,
+ const sockaddr* bind_addr,
+ const socklen_t expected_addrlen) {
+ // Connect the bound socket.
+ ASSERT_THAT(connect(sock.get(), bind_addr, expected_addrlen),
+ SyscallSucceeds());
+
+ // Disconnect.
+ {
+ sockaddr_storage unspec = {.ss_family = AF_UNSPEC};
+ ASSERT_THAT(connect(sock.get(), AsSockAddr(&unspec), sizeof(unspec)),
+ SyscallSucceeds());
+ }
+ {
+ // Check that we're not in a bound state.
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ ASSERT_THAT(getsockname(sock.get(), AsSockAddr(&addr), &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, expected_addrlen);
+ // Everything should be the zero value except the address family.
+ sockaddr_storage expected = {
+ .ss_family = bind_addr->sa_family,
+ };
+ EXPECT_EQ(memcmp(&expected, &addr, expected_addrlen), 0);
+ }
+
+ {
+ // We are not connected so we have no peer.
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getpeername(sock.get(), AsSockAddr(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+ }
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterBindToUnspecAndConnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ sockaddr_storage unspec = {.ss_family = AF_UNSPEC};
+ int bind_res = bind(sock_.get(), AsSockAddr(&unspec), sizeof(unspec));
+ if (IsRunningOnGvisor() && !IsRunningWithHostinet()) {
+ // TODO(https://gvisor.dev/issue/6575): Match Linux's behaviour.
+ ASSERT_THAT(bind_res, SyscallFailsWithErrno(EINVAL));
+ } else if (GetFamily() == AF_INET) {
+ // Linux allows this for undocumented compatibility reasons:
+ // https://github.com/torvalds/linux/commit/29c486df6a208432b370bd4be99ae1369ede28d8.
+ ASSERT_THAT(bind_res, SyscallSucceeds());
+ } else {
+ ASSERT_THAT(bind_res, SyscallFailsWithErrno(EAFNOSUPPORT));
+ }
+
+ ASSERT_NO_FATAL_FAILURE(ConnectThenDisconnect(sock_, bind_addr_, addrlen_));
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterConnectWithoutBind) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ ASSERT_NO_FATAL_FAILURE(ConnectThenDisconnect(sock_, bind_addr_, addrlen_));
+}
+
TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) {
ASSERT_NO_ERRNO(BindAny());