summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go24
-rw-r--r--pkg/tcpip/transport/udp/protocol.go4
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go33
3 files changed, 55 insertions, 6 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 52f5af777..8ae83437e 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -83,6 +83,7 @@ type endpoint struct {
route stack.Route `state:"manual"`
dstPort uint16
v6only bool
+ ttl uint8
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
@@ -374,12 +375,16 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrMessageTooLong
}
- ttl := route.DefaultTTL()
+ ttl := e.ttl
+ useDefaultTTL := ttl == 0
+
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
ttl = e.multicastTTL
+ // Multicast allows a 0 TTL.
+ useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl, useDefaultTTL); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -414,6 +419,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.v6only = v != 0
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -628,6 +638,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.TTLOption:
+ e.mu.Lock()
+ *o = tcpip.TTLOption(e.ttl)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.MulticastTTLOption:
e.mu.Lock()
*o = tcpip.MulticastTTLOption(e.multicastTTL)
@@ -694,7 +710,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -720,7 +736,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
- return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl)
+ return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL)
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index f5cc932dd..0b24bbc85 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -130,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv4DstUnreachable)
pkt.SetCode(header.ICMPv4PortUnreachable)
pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL())
+ r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */)
case header.IPv6AddressSize:
if !r.Stack().AllowICMPMessage() {
@@ -164,7 +164,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv6DstUnreachable)
pkt.SetCode(header.ICMPv6PortUnreachable)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL())
+ r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */)
}
return true
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index d1ba8d2ce..faa728b68 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1207,6 +1207,39 @@ func TestTTL(t *testing.T) {
}
}
+func TestSetTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
+ t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
+ }
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ep.Close()
+
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {