summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go81
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go119
2 files changed, 154 insertions, 46 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index a4ff29a7d..c9cbed8f4 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -31,6 +31,7 @@ type udpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -113,6 +114,10 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
+ // receiveTOS determines if the incoming IPv4 TOS header field is passed
+ // as ancillary data to ControlMessages on Read.
+ receiveTOS bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -243,7 +248,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
*addr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ cm := tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ }
+ e.mu.RLock()
+ receiveTOS := e.receiveTOS
+ e.mu.RUnlock()
+ if receiveTOS {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+ return p.data.ToView(), cm, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -402,7 +418,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- netProto, err := e.checkV4Mapped(to, false)
+ netProto, err := e.checkV4Mapped(to)
if err != nil {
return 0, nil, err
}
@@ -458,6 +474,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
+ case tcpip.ReceiveTOSOption:
+ e.mu.Lock()
+ e.receiveTOS = v
+ e.mu.Unlock()
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -501,7 +523,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- netProto, err := e.checkV4Mapped(&fa, false)
+ netProto, err := e.checkV4Mapped(&fa)
if err != nil {
return err
}
@@ -664,15 +686,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
+ case tcpip.ReceiveTOSOption:
+ e.mu.RLock()
+ v := e.receiveTOS
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
return false, tcpip.ErrUnknownProtocolOption
}
- e.mu.Lock()
+ e.mu.RLock()
v := e.v6only
- e.mu.Unlock()
+ e.mu.RUnlock()
return v, nil
}
@@ -839,35 +867,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if len(addr.Addr) == 0 {
- return netProto, nil
- }
- if header.IsV4MappedAddress(addr.Addr) {
- // Fail if using a v4 mapped address on a v6only endpoint.
- if e.v6only {
- return 0, tcpip.ErrNoRoute
- }
-
- netProto = header.IPv4ProtocolNumber
- addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == header.IPv4Any {
- addr.Addr = ""
- }
-
- // Fail if we are bound to an IPv6 address.
- if !allowMismatch && len(e.ID.LocalAddress) == 16 {
- return 0, tcpip.ErrNetworkUnreachable
- }
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+ if err != nil {
+ return 0, err
}
-
+ *addr = unwrapped
return netProto, nil
}
@@ -916,7 +921,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- netProto, err := e.checkV4Mapped(&addr, false)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
@@ -1074,7 +1079,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, true)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
@@ -1238,6 +1243,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvList.PushBack(packet)
e.rcvBufSize += pkt.Data.Size()
+ // Save any useful information from the network header to the packet.
+ switch r.NetProto {
+ case header.IPv4ProtocolNumber:
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ }
+
packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index d33507156..c6927cfe3 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -56,6 +56,7 @@ const (
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
+ testTOS = 0x80
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -273,11 +274,16 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
-
- s := stack.New(stack.Options{
+ return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
})
+}
+
+func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
+ t.Helper()
+
+ s := stack.New(options)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -453,6 +459,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
+ TOS: testTOS,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
@@ -552,8 +559,8 @@ func TestBindToDeviceOption(t *testing.T) {
// testReadInternal sends a packet of the given test flow into the stack by
// injecting it into the link endpoint. It then attempts to read it from the
// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
+// correctness including any additional checker functions provided.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
payload := newPayload()
@@ -568,12 +575,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
+ v, cm, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, _, err = c.ep.Read(&addr)
+ v, cm, err = c.ep.Read(&addr)
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -606,15 +613,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
if !bytes.Equal(payload, v) {
c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+
+ // Run any checkers against the ControlMessages.
+ for _, f := range checkers {
+ f(c.t, cm)
+ }
+
c.checkEndpointReadStats(1, epstats, err)
}
// testRead sends a packet of the given test flow into the stack by injecting it
// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness.
-func testRead(c *testContext, flow testFlow) {
+// its correctness including any additional checker functions provided.
+func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
}
// testFailingRead sends a packet of the given test flow into the stack by
@@ -755,6 +768,49 @@ func TestV6ReadOnV6(t *testing.T) {
testRead(c, unicastV6)
}
+// TestV4ReadSelfSource checks that packets coming from a local IP address are
+// correctly dropped when handleLocal is true and not otherwise.
+func TestV4ReadSelfSource(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ handleLocal bool
+ wantErr *tcpip.Error
+ wantInvalidSource uint64
+ }{
+ {"HandleLocal", false, nil, 0},
+ {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1},
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ HandleLocal: tt.handleLocal,
+ })
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ h.srcAddr = h.dstAddr
+
+ c.injectV4Packet(payload, &h, true /* valid */)
+
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
+ t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
+ }
+
+ if _, _, err := c.ep.Read(nil); err != tt.wantErr {
+ t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1288,7 +1344,7 @@ func TestTOSV4(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1323,7 +1379,7 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1350,6 +1406,47 @@ func TestTOSV6(t *testing.T) {
}
}
+func TestReceiveTOSV4(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false)
+ }
+
+ want := true
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err)
+ }
+
+ got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
+ }
+ if got != want {
+ c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want)
+ }
+
+ // Verify that the correct received TOS is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {