diff options
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 46 | ||||
-rw-r--r-- | test/syscalls/linux/udp_socket.cc | 144 |
4 files changed, 212 insertions, 29 deletions
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 0dce60d89..c5b575e1c 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -60,10 +60,8 @@ type Endpoint struct { multicastAddr tcpip.Address // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. multicastNICID tcpip.NICID - // sendTOS represents IPv4 TOS or IPv6 TrafficClass, - // applied while sending packets. Defaults to 0 as on Linux. - // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. - sendTOS uint8 + ipv4TOS uint8 + ipv6TClass uint8 } // +stateify savable @@ -267,11 +265,21 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext return WriteContext{}, &tcpip.ErrBroadcastDisabled{} } + var tos uint8 + switch netProto := route.NetProto(); netProto { + case header.IPv4ProtocolNumber: + tos = e.ipv4TOS + case header.IPv6ProtocolNumber: + tos = e.ipv6TClass + default: + panic(fmt.Sprintf("invalid protocol number = %d", netProto)) + } + return WriteContext{ transProto: e.transProto, route: route, ttl: calculateTTL(route, e.ttl, e.multicastTTL), - tos: e.sendTOS, + tos: tos, owner: e.owner, }, nil } @@ -533,12 +541,12 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { case tcpip.IPv4TOSOption: e.mu.Lock() - e.sendTOS = uint8(v) + e.ipv4TOS = uint8(v) e.mu.Unlock() case tcpip.IPv6TrafficClassOption: e.mu.Lock() - e.sendTOS = uint8(v) + e.ipv6TClass = uint8(v) e.mu.Unlock() } @@ -566,13 +574,13 @@ func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { case tcpip.IPv4TOSOption: e.mu.RLock() - v := int(e.sendTOS) + v := int(e.ipv4TOS) e.mu.RUnlock() return v, nil case tcpip.IPv6TrafficClassOption: e.mu.RLock() - v := int(e.sendTOS) + v := int(e.ipv6TClass) e.mu.RUnlock() return v, nil diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4b6bdc3be..f171a16f8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -33,6 +33,7 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry + netProto tcpip.NetworkProtocolNumber senderAddress tcpip.FullAddress destinationAddress tcpip.FullAddress packetInfo tcpip.IPPacketInfo @@ -235,14 +236,21 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult HasTimestamp: true, Timestamp: p.receivedAt.UnixNano(), } - if e.ops.GetReceiveTOS() { - cm.HasTOS = true - cm.TOS = p.tos - } - if e.ops.GetReceiveTClass() { - cm.HasTClass = true - // Although TClass is an 8-bit value it's read in the CMsg as a uint32. - cm.TClass = uint32(p.tos) + + switch p.netProto { + case header.IPv4ProtocolNumber: + if e.ops.GetReceiveTOS() { + cm.HasTOS = true + cm.TOS = p.tos + } + case header.IPv6ProtocolNumber: + if e.ops.GetReceiveTClass() { + cm.HasTClass = true + // Although TClass is an 8-bit value it's read in the CMsg as a uint32. + cm.TClass = uint32(p.tos) + } + default: + panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto)) } if e.ops.GetReceivePacketInfo() { cm.HasIPPacketInfo = true @@ -888,6 +896,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB // Push new packet into receive list and increment the buffer size. packet := &udpPacket{ + netProto: pkt.NetworkProtocolNumber, senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4008cacf2..554ce1de4 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -290,6 +290,7 @@ type testContext struct { t *testing.T linkEP *channel.Endpoint s *stack.Stack + nicID tcpip.NICID ep tcpip.Endpoint wq waiter.Queue @@ -301,6 +302,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { } func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext { + const nicID = 1 + t.Helper() options := stack.Options{ @@ -316,32 +319,33 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) + if err := s.CreateNIC(nicID, wep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %s", err) + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err) } - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %s", err) + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress((%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, stackV6Addr, err) } s.SetRouteTable([]tcpip.Route{ { Destination: header.IPv4EmptySubnet, - NIC: 1, + NIC: nicID, }, { Destination: header.IPv6EmptySubnet, - NIC: 1, + NIC: nicID, }, }) return &testContext{ t: t, s: s, + nicID: nicID, linkEP: ep, } } @@ -1644,8 +1648,10 @@ func TestSetTTL(t *testing.T) { } } +var v4PacketFlows = [...]testFlow{unicastV4, multicastV4, broadcast, unicastV4in6, multicastV4in6, broadcastIn6} + func TestSetTOS(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + for _, flow := range v4PacketFlows { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1680,8 +1686,10 @@ func TestSetTOS(t *testing.T) { } } +var v6PacketFlows = [...]testFlow{unicastV6, unicastV6Only, multicastV6} + func TestSetTClass(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { + for _, flow := range v6PacketFlows { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1725,8 +1733,14 @@ func TestReceiveTosTClass(t *testing.T) { name string tests []testFlow }{ - {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, - {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + { + name: RcvTOSOpt, + tests: v4PacketFlows[:], + }, + { + name: RcvTClassOpt, + tests: v6PacketFlows[:], + }, } for _, testCase := range testCases { for _, flow := range testCase.tests { @@ -1737,6 +1751,14 @@ func TestReceiveTosTClass(t *testing.T) { c.createEndpointForFlow(flow) name := testCase.name + if flow.isMulticast() { + netProto := flow.netProto() + addr := flow.getMcastAddr() + if err := c.s.JoinGroup(netProto, c.nicID, addr); err != nil { + c.t.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, c.nicID, addr, err) + } + } + var optionGetter func() bool var optionSetter func(bool) switch name { diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 3353e58cb..d58b57c8b 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -18,6 +18,8 @@ #include <netinet/ip_icmp.h> #include <ctime> +#include <utility> +#include <vector> #ifdef __linux__ #include <linux/errqueue.h> @@ -1685,6 +1687,148 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { ASSERT_EQ(tv.tv_usec, tv2.tv_usec); } +// TOS and TCLASS values may be different but IPv6 sockets with IPv4-mapped-IPv6 +// addresses use TOS (IPv4), not TCLASS (IPv6). +TEST_P(UdpSocketTest, DifferentTOSAndTClass) { + const int kFamily = GetFamily(); + constexpr int kToS = IPTOS_LOWDELAY; + constexpr int kTClass = IPTOS_THROUGHPUT; + ASSERT_NE(kToS, kTClass); + + if (kFamily == AF_INET6) { + ASSERT_THAT(setsockopt(sock_.get(), SOL_IPV6, IPV6_TCLASS, &kTClass, + sizeof(kTClass)), + SyscallSucceeds()); + + // Marking an IPv6 socket as IPv6 only should not affect the ability to + // configure IPv4 socket options as the V6ONLY flag may later be disabled so + // that applications may use the socket to send/receive IPv4 packets. + constexpr int on = 1; + ASSERT_THAT(setsockopt(sock_.get(), SOL_IPV6, IPV6_V6ONLY, &on, sizeof(on)), + SyscallSucceeds()); + } + + ASSERT_THAT(setsockopt(sock_.get(), SOL_IP, IP_TOS, &kToS, sizeof(kToS)), + SyscallSucceeds()); + + if (kFamily == AF_INET6) { + int got_tclass; + socklen_t got_tclass_len = sizeof(got_tclass); + ASSERT_THAT(getsockopt(sock_.get(), SOL_IPV6, IPV6_TCLASS, &got_tclass, + &got_tclass_len), + SyscallSucceeds()); + ASSERT_EQ(got_tclass_len, sizeof(got_tclass)); + EXPECT_EQ(got_tclass, kTClass); + } + + { + int got_tos; + socklen_t got_tos_len = sizeof(got_tos); + ASSERT_THAT(getsockopt(sock_.get(), SOL_IP, IP_TOS, &got_tos, &got_tos_len), + SyscallSucceeds()); + ASSERT_EQ(got_tos_len, sizeof(got_tos)); + EXPECT_EQ(got_tos, kToS); + } + + auto test_send = [this](sockaddr_storage addr, + std::function<void(const cmsghdr*)> cb) { + FileDescriptor bind = ASSERT_NO_ERRNO_AND_VALUE( + Socket(addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)); + ASSERT_NO_ERRNO(BindSocket(bind.get(), reinterpret_cast<sockaddr*>(&addr))); + ASSERT_THAT(setsockopt(bind.get(), SOL_IP, IP_RECVTOS, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + if (addr.ss_family == AF_INET6) { + ASSERT_THAT(setsockopt(bind.get(), SOL_IPV6, IPV6_RECVTCLASS, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + } + + char sent_data[1024]; + iovec sent_iov = { + .iov_base = sent_data, + .iov_len = sizeof(sent_data), + }; + msghdr sent_msg = { + .msg_name = &addr, + .msg_namelen = sizeof(addr), + .msg_iov = &sent_iov, + .msg_iovlen = 1, + }; + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[sizeof(sent_data) + 1]; + iovec received_iov = { + .iov_base = received_data, + .iov_len = sizeof(received_data), + }; + std::vector<char> received_cmsgbuf(CMSG_SPACE(sizeof(int8_t))); + msghdr received_msg = { + .msg_iov = &received_iov, + .msg_iovlen = 1, + .msg_control = received_cmsgbuf.data(), + .msg_controllen = static_cast<socklen_t>(received_cmsgbuf.size()), + }; + ASSERT_THAT(RetryEINTR(recvmsg)(bind.get(), &received_msg, 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_NO_FATAL_FAILURE(cb(cmsg)); + EXPECT_EQ(CMSG_NXTHDR(&received_msg, cmsg), nullptr); + }; + + if (kFamily == AF_INET6) { + SCOPED_TRACE( + "Send IPv4 loopback packet using IPv6 socket via IPv4-mapped-IPv6"); + + constexpr int off = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_IPV6, IPV6_V6ONLY, &off, sizeof(off)), + SyscallSucceeds()); + + // Send a packet and make sure that the ToS value in the IPv4 header is + // the configured IPv4 ToS Value and not the IPv6 Traffic Class value even + // though we use an IPv6 socket to send an IPv4 packet. + ASSERT_NO_FATAL_FAILURE( + test_send(V4MappedLoopback().addr, [kToS](const cmsghdr* cmsg) { + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int8_t))); + EXPECT_EQ(cmsg->cmsg_level, SOL_IP); + EXPECT_EQ(cmsg->cmsg_type, IP_TOS); + int8_t received; + memcpy(&received, CMSG_DATA(cmsg), sizeof(received)); + EXPECT_EQ(received, kToS); + })); + } + + { + SCOPED_TRACE("Send loopback packet"); + + ASSERT_NO_FATAL_FAILURE(test_send( + InetLoopbackAddr(), [kFamily, kTClass, kToS](const cmsghdr* cmsg) { + switch (kFamily) { + case AF_INET: { + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int8_t))); + EXPECT_EQ(cmsg->cmsg_level, SOL_IP); + EXPECT_EQ(cmsg->cmsg_type, IP_TOS); + int8_t received; + memcpy(&received, CMSG_DATA(cmsg), sizeof(received)); + EXPECT_EQ(received, kToS); + } break; + case AF_INET6: { + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int32_t))); + EXPECT_EQ(cmsg->cmsg_level, SOL_IPV6); + EXPECT_EQ(cmsg->cmsg_type, IPV6_TCLASS); + int32_t received; + memcpy(&received, CMSG_DATA(cmsg), sizeof(received)); + EXPECT_EQ(received, kTClass); + } break; + } + })); + } +} + // Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. |