summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
authorRyan Heacock <rheacock@google.com>2019-12-24 08:48:14 -0800
committergVisor bot <gvisor-bot@google.com>2019-12-24 08:49:39 -0800
commite013c48c78c9a7daf245b7de9563e3a0bd8a1e97 (patch)
tree4569d1851e72e61faebc63c82997555afdde04a9 /pkg/tcpip/transport/udp
parent574e988f2bc6060078a17f37a377441703c52a22 (diff)
Enable IP_RECVTOS socket option for datagram sockets
Added the ability to get/set the IP_RECVTOS socket option on UDP endpoints. If enabled, TOS from the incoming Network Header passed as ancillary data in the ControlMessages. Test: * Added unit test to udp_test.go that tests getting/setting as well as verifying that we receive expected TOS from incoming packet. * Added a syscall test PiperOrigin-RevId: 287029703
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go31
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go69
2 files changed, 90 insertions, 10 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1ac4705af..269470ed4 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -32,6 +32,7 @@ type udpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -114,6 +115,10 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
+ // receiveTOS determines if the incoming IPv4 TOS header field is passed
+ // as ancillary data to ControlMessages on Read.
+ receiveTOS bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -244,7 +249,12 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
*addr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ return p.data.ToView(), tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ HasTOS: e.receiveTOS,
+ TOS: p.tos,
+ }, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -656,6 +666,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.sendTOS = uint8(v)
e.mu.Unlock()
return nil
+
+ case tcpip.ReceiveTOSOption:
+ e.mu.Lock()
+ e.receiveTOS = bool(v)
+ e.mu.Unlock()
+ return nil
}
return nil
}
@@ -792,6 +808,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.RUnlock()
return nil
+ case *tcpip.ReceiveTOSOption:
+ e.mu.RLock()
+ *o = tcpip.ReceiveTOSOption(e.receiveTOS)
+ e.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1238,6 +1260,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvList.PushBack(packet)
e.rcvBufSize += pkt.Data.Size()
+ // Save any useful information from the NetworkHeader to the packet.
+ switch r.NetProto {
+ case header.IPv4ProtocolNumber:
+ // This packet has already been validated before being passed up the stack.
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ }
+
packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 7051a7a9c..43b8b35ba 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -56,6 +56,7 @@ const (
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
+ testTOS = 0x80
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -453,6 +454,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
+ TOS: testTOS,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
@@ -556,8 +558,8 @@ func TestBindToDeviceOption(t *testing.T) {
// testReadInternal sends a packet of the given test flow into the stack by
// injecting it into the link endpoint. It then attempts to read it from the
// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
+// correctness including any additional checker functions provided.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
payload := newPayload()
@@ -572,12 +574,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
+ v, cm, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, _, err = c.ep.Read(&addr)
+ v, cm, err = c.ep.Read(&addr)
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -610,15 +612,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
if !bytes.Equal(payload, v) {
c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+
+ // Run any checkers against the ControlMessages.
+ for _, f := range checkers {
+ f(c.t, cm)
+ }
+
c.checkEndpointReadStats(1, epstats, err)
}
// testRead sends a packet of the given test flow into the stack by injecting it
// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness.
-func testRead(c *testContext, flow testFlow) {
+// its correctness including any additional checker functions provided.
+func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
}
// testFailingRead sends a packet of the given test flow into the stack by
@@ -1286,7 +1294,7 @@ func TestTOSV4(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1321,7 +1329,7 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1348,6 +1356,49 @@ func TestTOSV6(t *testing.T) {
}
}
+func TestReceiveTOSV4(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Verify that setting and reading the option works.
+ const recvTos = true
+ var v tcpip.ReceiveTOSOption
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, false)
+ }
+
+ if err := c.ep.SetSockOpt(tcpip.ReceiveTOSOption(recvTos)); err != nil {
+ c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.ReceiveTOSOption(recvTos), err)
+ }
+
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.ReceiveTOSOption(recvTos); v != want {
+ c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, want)
+ }
+
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Verify that the correct received TOS is actually handed through as
+ // ancillary data to the ControlMessages struct.
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {