summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go37
-rw-r--r--pkg/tcpip/transport/udp/protocol.go4
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go70
3 files changed, 106 insertions, 5 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 2ad7978c2..6e87245b7 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -106,6 +106,10 @@ type endpoint struct {
bindToDevice tcpip.NICID
broadcast bool
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -429,7 +433,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -628,6 +632,18 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+ return nil
}
return nil
}
@@ -748,6 +764,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ *o = tcpip.IPv4TOSOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
+ case *tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
+ e.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -755,7 +783,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -778,7 +806,10 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
- if err := r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL); err != nil {
+ if useDefaultTTL {
+ ttl = r.DefaultTTL()
+ }
+ if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 0b24bbc85..de026880f 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -130,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv4DstUnreachable)
pkt.SetCode(header.ICMPv4PortUnreachable)
pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */)
+ r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
case header.IPv6AddressSize:
if !r.Stack().AllowICMPMessage() {
@@ -164,7 +164,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv6DstUnreachable)
pkt.SetCode(header.ICMPv6PortUnreachable)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */)
+ r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
}
return true
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4ada73475..b724d788c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1267,6 +1267,76 @@ func TestSetTTL(t *testing.T) {
}
}
+func TestTOSV4(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const tos = 0xC0
+ var v tcpip.IPv4TOSOption
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != 0 {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ }
+
+ if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
+ c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ }
+
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv4TOSOption(tos); v != want {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ testWrite(c, flow, checker.TOS(tos, 0))
+ })
+ }
+}
+
+func TestTOSV6(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ const tos = 0xC0
+ var v tcpip.IPv6TrafficClassOption
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+ // Test for expected default value.
+ if v != 0 {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ }
+
+ if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
+ c.t.Errorf("SetSockOpt failed: %s", err)
+ }
+
+ if err := c.ep.GetSockOpt(&v); err != nil {
+ c.t.Errorf("GetSockopt failed: %s", err)
+ }
+
+ if want := tcpip.IPv6TrafficClassOption(tos); v != want {
+ c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ }
+
+ testWrite(c, flow, checker.TOS(tos, 0))
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {