diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 47 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/protocol.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 33 |
20 files changed, 265 insertions, 48 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 48f5c6f30..b3546471e 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -174,6 +174,10 @@ var Metrics = tcpip.Stats{ }, } +// DefaultTTL is linux's default TTL. All network protocols in all stacks used +// with this package must have this value set as their default TTL. +const DefaultTTL = 64 + const sizeOfInt32 int = 4 var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) @@ -833,7 +837,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return getSockOptIPv6(t, ep, name, outLen) case linux.SOL_IP: - return getSockOptIP(t, ep, name, outLen) + return getSockOptIP(t, ep, name, outLen, family) case linux.SOL_UDP, linux.SOL_ICMPV6, @@ -1176,8 +1180,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { switch name { + case linux.IP_TTL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TTLOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + // Fill in the default value, if needed. + if v == 0 { + v = DefaultTTL + } + + return int32(v), nil + case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1648,6 +1669,20 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s t.Kernel().EmitUnimplementedEvent(t) return syserr.ErrInvalidArgument + case linux.IP_TTL: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + // -1 means default TTL. + if v == -1 { + v = 0 + } else if v < 1 || v > 255 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1673,7 +1708,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RETOPTS, linux.IP_TOS, linux.IP_TRANSPARENT, - linux.IP_TTL, linux.IP_UNBLOCK_SOURCE, linux.IP_UNICAST_IF, linux.IP_XFRM_POLICY, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index a25756443..c1cf6c222 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -95,7 +95,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b7b07a6c1..162aa1b4d 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -39,6 +39,9 @@ const ( // TotalLength field of the ipv4 header. MaxTotalSize = 0xffff + // DefaultTTL is the default time-to-live value for this endpoint. + DefaultTTL = 64 + // buckets is the number of identifier buckets. buckets = 2048 ) @@ -70,7 +73,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi // DefaultTTL is the default time-to-live value for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -327,6 +330,11 @@ func (e *endpoint) Close() {} type protocol struct { ids []uint32 hashIV uint32 + + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 } // Number returns the ipv4 protocol number. @@ -352,12 +360,34 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -391,5 +421,5 @@ func NewProtocol() stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ids: ids, hashIV: hashIV} + return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index a53894c01..8b7500095 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -302,7 +302,7 @@ func TestFragmentation(t *testing.T) { Payload: payload.Clone([]buffer.View{}), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */) if err != nil { t.Errorf("err got %v, want %v", err, nil) } @@ -349,7 +349,7 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b4d0295bf..71c398027 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -154,7 +154,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r.LocalAddress = targetAddr pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */); err != nil { sent.Dropped.Increment() return } @@ -185,7 +185,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */); err != nil { sent.Dropped.Increment() return } @@ -262,7 +262,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: defaultIPv6HopLimit, + HopLimit: ndpHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 01f5a17ec..501be208e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -144,7 +144,7 @@ func TestICMPCounts(t *testing.T) { ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: r.DefaultTTL(), + HopLimit: ndpHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7de6a4546..85c070c43 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -21,6 +21,8 @@ package ipv6 import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,9 +37,9 @@ const ( // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff - // defaultIPv6HopLimit is the default hop limit for IPv6 Packets - // egressed by Netstack. - defaultIPv6HopLimit = 255 + // DefaultTTL is the default hop limit for IPv6 Packets egressed by + // Netstack. + DefaultTTL = 64 ) type endpoint struct { @@ -47,11 +49,12 @@ type endpoint struct { linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher + protocol *protocol } // DefaultTTL is the default hop limit for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -155,7 +158,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { // Close cleans up resources associated with the endpoint. func (*endpoint) Close() {} -type protocol struct{} +type protocol struct { + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 +} // Number returns the ipv6 protocol number. func (p *protocol) Number() tcpip.NetworkProtocolNumber { @@ -187,17 +195,40 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, + protocol: p, }, nil } // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -212,5 +243,5 @@ func calculateMTU(mtu uint32) uint32 { // NewProtocol returns an IPv6 network protocol. func NewProtocol() stack.NetworkProtocol { - return &protocol{} + return &protocol{defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 0b09e6517..22db619a1 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -154,11 +154,15 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, useDefaultTTL bool) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } + if useDefaultTTL { + ttl = r.DefaultTTL() + } + err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index a2e0a6e7b..8f063038b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -310,7 +310,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) + return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123 /* ttl */, false /* useDefaultTTL */) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 842a16277..db290c404 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -75,7 +75,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { + if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123 /* ttl */, false /* useDefaultTTL */); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 70e7575f5..a35d6562e 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -553,6 +553,12 @@ type ModerateReceiveBufferOption bool // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. type MaxSegOption int +// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop +// limit value for unicast messages. The default is protocol specific. +// +// A zero value indicates the default. +type TTLOption uint8 + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -594,6 +600,10 @@ type OutOfBandInlineOption int // datagram sockets are allowed to send packets to a broadcast address. type BroadcastOption int +// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify +// a default TTL. +type DefaultTTLOption uint8 + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a3a910d41..e17e737e4 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -84,6 +84,7 @@ type endpoint struct { // NIC. regNICID tcpip.NICID route stack.Route `state:"manual"` + ttl uint8 } func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -296,10 +297,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c switch e.netProto { case header.IPv4ProtocolNumber: - err = send4(route, e.id.LocalPort, v) + err = send4(route, e.id.LocalPort, v, e.ttl) case header.IPv6ProtocolNumber: - err = send6(route, e.id.LocalPort, v) + err = send6(route, e.id.LocalPort, v, e.ttl) } if err != nil { @@ -314,8 +315,15 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(o) + e.mu.Unlock() + } + return nil } @@ -362,12 +370,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 0 return nil + case *tcpip.TTLOption: + e.rcvMu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.rcvMu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } -func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } @@ -389,10 +403,10 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */) } -func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return tcpip.ErrInvalidEndpointState } @@ -412,7 +426,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv6.SetChecksum(0) icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index a02731a5d..8ae132efa 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -332,7 +332,7 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil { + if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, 0, true /* useDefaultTTL */); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3ae4a5426..b8b4bcee8 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -439,7 +439,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TSEcr: opts.TSVal, MSS: uint16(mss), } - sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + sendSynTCP(&s.route, s.id, e.ttl, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 21038a65a..1d6e7f5f3 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -238,6 +238,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() h.ep.state = StateSynRecv + ttl := h.ep.ttl h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), @@ -251,7 +252,10 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { SACKPermitted: rcvSynOpts.SACKPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + if ttl == 0 { + ttl = s.route.DefaultTTL() + } + sendSynTCP(&s.route, h.ep.id, ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -296,7 +300,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + sendSynTCP(&s.route, h.ep.id, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -460,7 +464,7 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + sendSynTCP(&h.ep.route, h.ep.id, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) for h.state != handshakeCompleted { switch index, _ := s.Fetch(true); index { case wakerForResend: @@ -469,7 +473,7 @@ func (h *handshake) execute() *tcpip.Error { return tcpip.ErrTimeout } rt.Reset(timeOut) - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + sendSynTCP(&h.ep.route, h.ep.id, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) case wakerForNotification: n := h.ep.fetchNotifications() @@ -579,9 +583,9 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { +func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { options := makeSynOptions(opts) - err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil) + err := sendTCP(r, id, buffer.VectorisedView{}, ttl, flags, seq, ack, rcvWnd, options, nil) putOptions(options) return err } @@ -629,7 +633,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise r.Stats().TCP.ResetsSent.Increment() } - return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl) + return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */) } // makeOptions makes an options slice. @@ -678,7 +682,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso) + err := sendTCP(&e.route, e.id, data, e.ttl, flags, seq, ack, rcvWnd, options, e.gso) putOptions(options) return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f9d5e0085..83d92b3e1 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -228,6 +228,7 @@ type endpoint struct { isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` + ttl uint8 v6only bool isConnectNotified bool // TCP should never broadcast but Linux nevertheless supports enabling/ @@ -1116,6 +1117,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.v6only = v != 0 return nil + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + return nil + case tcpip.KeepaliveEnabledOption: e.keepalive.Lock() e.keepalive.enabled = v != 0 @@ -1313,6 +1320,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.mu.RLock() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 089826a88..a86123829 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1752,6 +1752,34 @@ func TestSendGreaterThanMTU(t *testing.T) { testBrokenUpWrite(t, c, maxPayload) } +func TestSetTTL(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := context.New(t, 65535) + defer c.Cleanup() + + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + t.Fatalf("SetSockOpt failed: %v", err) + } + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + + checker.IPv4(t, b, checker.TTL(wantTTL)) + }) + } +} + func TestActiveSendMSSLessThanMTU(t *testing.T) { const maxPayload = 100 c := context.New(t, 65535) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 52f5af777..8ae83437e 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -83,6 +83,7 @@ type endpoint struct { route stack.Route `state:"manual"` dstPort uint16 v6only bool + ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID @@ -374,12 +375,16 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrMessageTooLong } - ttl := route.DefaultTTL() + ttl := e.ttl + useDefaultTTL := ttl == 0 + if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { ttl = e.multicastTTL + // Multicast allows a 0 TTL. + useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl, useDefaultTTL); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -414,6 +419,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.v6only = v != 0 + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -628,6 +638,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.MulticastTTLOption: e.mu.Lock() *o = tcpip.MulticastTTLOption(e.multicastTTL) @@ -694,7 +710,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -720,7 +736,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl) + return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f5cc932dd..0b24bbc85 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -130,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -164,7 +164,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */) } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index d1ba8d2ce..faa728b68 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1207,6 +1207,39 @@ func TestTTL(t *testing.T) { } } +func TestSetTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() + } + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + ep.Close() + + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { |