summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-11-13 13:10:51 -0800
committergVisor bot <gvisor-bot@google.com>2020-11-13 13:13:21 -0800
commit6c0f53002a7f3a518befbe667d308c3d64cc9a59 (patch)
tree50119065f7d1e050034d7c875ef5816d19b20903 /pkg/tcpip
parentd5e17d2dbc2809c6d70153f0d4c996eff899e69d (diff)
Decrement TTL/Hop Limit when forwarding IP packets
If the packet must no longer be forwarded because its TTL/Hop Limit reaches 0, send an ICMP Time Exceeded error to the source. Required as per relevant RFCs. See comments in code for RFC references. Fixes #1085 Tests: - ipv4_test.TestForwarding - ipv6.TestForwarding PiperOrigin-RevId: 342323610
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/checker/checker.go12
-rw-r--r--pkg/tcpip/header/ipv6.go9
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go33
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go29
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go157
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go32
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go29
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go158
8 files changed, 443 insertions, 16 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 8868cf4e3..81f762e10 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -904,6 +904,12 @@ func ICMPv4Payload(want []byte) TransportChecker {
t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
payload := icmpv4.Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(payload) == 0 {
+ return
+ }
+
if diff := cmp.Diff(want, payload); diff != "" {
t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
@@ -994,6 +1000,12 @@ func ICMPv6Payload(want []byte) TransportChecker {
t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
payload := icmpv6.Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(payload) == 0 {
+ return
+ }
+
if diff := cmp.Diff(want, payload); diff != "" {
t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 4e7e5f76a..55d09355a 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -54,7 +54,7 @@ type IPv6Fields struct {
// NextHeader is the "next header" field of an IPv6 packet.
NextHeader uint8
- // HopLimit is the "hop limit" field of an IPv6 packet.
+ // HopLimit is the "Hop Limit" field of an IPv6 packet.
HopLimit uint8
// SrcAddr is the "source ip address" of an IPv6 packet.
@@ -171,7 +171,7 @@ func (b IPv6) PayloadLength() uint16 {
return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
-// HopLimit returns the value of the "hop limit" field of the ipv6 header.
+// HopLimit returns the value of the "Hop Limit" field of the ipv6 header.
func (b IPv6) HopLimit() uint8 {
return b[hopLimit]
}
@@ -236,6 +236,11 @@ func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
+// SetHopLimit sets the value of the "Hop Limit" field.
+func (b IPv6) SetHopLimit(v uint8) {
+ b[hopLimit] = v
+}
+
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
func (b IPv6) SetNextHeader(v uint8) {
b[IPv6NextHeaderOffset] = v
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 9b5e37fee..58a19e74a 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -290,6 +290,13 @@ type icmpReasonProtoUnreachable struct{}
func (*icmpReasonProtoUnreachable) isICMPReason() {}
+// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in
+// transit to its final destination, as per RFC 792 page 6, Time Exceeded
+// Message.
+type icmpReasonTTLExceeded struct{}
+
+func (*icmpReasonTTLExceeded) isICMPReason() {}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -342,11 +349,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
+ // If we hit a TTL Exceeded error, then we know we are operating as a router.
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ //
+ // ...
+ //
+ // Code 0 may be received from a gateway. ...
+ //
+ // Note, Code 0 is the TTL exceeded error.
+ //
+ // If we are operating as a router/gateway, don't use the packet's destination
+ // address as the response's source address as we should not not own the
+ // destination address of a packet we are forwarding.
+ localAddr := origIPHdrDst
+ if _, ok := reason.(*icmpReasonTTLExceeded); ok {
+ localAddr = ""
+ }
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -454,6 +481,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonTTLExceeded:
+ icmpHdr.SetType(header.ICMPv4TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv4TTLExceeded)
+ counter = sent.TimeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 0c828004a..b4f21d61e 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -485,6 +485,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// forwardPacket attempts to forward a packet to its final destination.
func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
h := header.IPv4(pkt.NetworkHeader().View())
+ ttl := h.TTL()
+ if ttl == 0 {
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ }
+
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
@@ -503,13 +513,22 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
}
defer r.Release()
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+
+ // As per RFC 791 page 30, Time to Live,
+ //
+ // This field must be decreased at each point that the internet header
+ // is processed to reflect the time spent processing the datagram.
+ // Even if no local information is available on the time actually
+ // spent, the field must be decremented by 1.
+ newHdr.SetTTL(ttl - 1)
+
return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
- // We need to do a deep copy of the IP packet because
- // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
- // not own it.
- Data: stack.PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ Data: buffer.View(newHdr).ToVectorisedView(),
}))
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index a7100b971..8b0d2d794 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -202,6 +202,163 @@ func TestIPv4EncodeOptions(t *testing.T) {
}
}
+func TestForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ )
+
+ ipv4Addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ }
+ ipv4Addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
+ PrefixLen: 8,
+ }
+ remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
+ remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
+
+ tests := []struct {
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ }{
+ {
+ name: "TTL of zero",
+ TTL: 0,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of one",
+ TTL: 1,
+ expectErrorICMP: false,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectErrorICMP: false,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectErrorICMP: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e1 := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1}
+ if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
+ }
+
+ e2 := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2}
+ if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv4Addr1.Subnet(),
+ NIC: nicID1,
+ },
+ {
+ Destination: ipv4Addr2.Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err)
+ }
+
+ totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: totalLen,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr1,
+ DstAddr: remoteIPv4Addr2,
+ })
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+
+ if test.expectErrorICMP {
+ reply, ok := e1.Read()
+ if !ok {
+ t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(ipv4Addr1.Address),
+ checker.DstAddr(remoteIPv4Addr1),
+ checker.TTL(ipv4.DefaultTTL),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4TimeExceeded),
+ checker.ICMPv4Code(header.ICMPv4TTLExceeded),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+
+ if n := e2.Drain(); n != 0 {
+ t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ }
+ } else {
+ reply, ok := e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(remoteIPv4Addr1),
+ checker.DstAddr(remoteIPv4Addr2),
+ checker.TTL(test.TTL-1),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Payload(nil),
+ ),
+ )
+
+ if n := e1.Drain(); n != 0 {
+ t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ }
+ }
+ })
+ }
+}
+
// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
// checks the response.
func TestIPv4Sanity(t *testing.T) {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 8502b848c..8d788af80 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -750,6 +750,12 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in
+// transit to its final destination, as per RFC 4443 section 3.3.
+type icmpReasonHopLimitExceeded struct{}
+
+func (*icmpReasonHopLimitExceeded) isICMPReason() {}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -794,11 +800,27 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
+ // If we hit a Hop Limit Exceeded error, then we know we are operating as a
+ // router. As per RFC 4443 section 3.3:
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard the
+ // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
+ // the source of the packet. This indicates either a routing loop or
+ // too small an initial Hop Limit value.
+ //
+ // If we are operating as a router, do not use the packet's destination
+ // address as the response's source address as we should not own the
+ // destination address of a packet we are forwarding.
+ localAddr := origIPHdrDst
+ if _, ok := reason.(*icmpReasonHopLimitExceeded); ok {
+ localAddr = ""
+ }
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -811,8 +833,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
// TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
// Unfortunately at this time ICMP Packets do not have a transport
@@ -830,6 +850,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
}
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+
// As per RFC 4443 section 2.4
//
// (c) Every ICMPv6 error message (type < 128) MUST include
@@ -873,6 +895,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonHopLimitExceeded:
+ icmpHdr.SetType(header.ICMPv6TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
+ counter = sent.TimeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 38a0633bd..7697ff987 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -645,6 +645,18 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// forwardPacket attempts to forward a packet to its final destination.
func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
h := header.IPv6(pkt.NetworkHeader().View())
+ hopLimit := h.HopLimit()
+ if hopLimit <= 1 {
+ // As per RFC 4443 section 3.3,
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard the
+ // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
+ // the source of the packet. This indicates either a routing loop or
+ // too small an initial Hop Limit value.
+ return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ }
+
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
@@ -663,13 +675,20 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
}
defer r.Release()
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+
+ // As per RFC 8200 section 3,
+ //
+ // Hop Limit 8-bit unsigned integer. Decremented by 1 by
+ // each node that forwards the packet.
+ newHdr.SetHopLimit(hopLimit - 1)
+
return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
- // We need to do a deep copy of the IP packet because
- // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
- // not own it.
- Data: stack.PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ Data: buffer.View(newHdr).ToVectorisedView(),
}))
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 1bfcdde25..a671d4bac 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -18,6 +18,7 @@ import (
"encoding/hex"
"fmt"
"math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
@@ -2821,3 +2822,160 @@ func TestFragmentationErrors(t *testing.T) {
})
}
}
+
+func TestForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ )
+
+ ipv6Addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ }
+ ipv6Addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11::1").To16()),
+ PrefixLen: 64,
+ }
+ remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
+ remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
+
+ tests := []struct {
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ }{
+ {
+ name: "TTL of zero",
+ TTL: 0,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of one",
+ TTL: 1,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectErrorICMP: false,
+ },
+ {
+ name: "TTL of three",
+ TTL: 3,
+ expectErrorICMP: false,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectErrorICMP: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1}
+ if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2}
+ if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv6Addr1.Subnet(),
+ NIC: nicID1,
+ },
+ {
+ Destination: ipv6Addr2.Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ if err := s.SetForwarding(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err)
+ }
+
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize)
+ icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv6EchoRequest)
+ icmp.SetCode(header.ICMPv6UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: test.TTL,
+ SrcAddr: remoteIPv6Addr1,
+ DstAddr: remoteIPv6Addr2,
+ })
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e1.InjectInbound(ProtocolNumber, requestPkt)
+
+ if test.expectErrorICMP {
+ reply, ok := e1.Read()
+ if !ok {
+ t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC")
+ }
+
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(ipv6Addr1.Address),
+ checker.DstAddr(remoteIPv6Addr1),
+ checker.TTL(DefaultTTL),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6TimeExceeded),
+ checker.ICMPv6Code(header.ICMPv6HopLimitExceeded),
+ checker.ICMPv6Payload([]byte(hdr.View())),
+ ),
+ )
+
+ if n := e2.Drain(); n != 0 {
+ t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ }
+ } else {
+ reply, ok := e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
+ }
+
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(remoteIPv6Addr1),
+ checker.DstAddr(remoteIPv6Addr2),
+ checker.TTL(test.TTL-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode),
+ checker.ICMPv6Payload(nil),
+ ),
+ )
+
+ if n := e1.Drain(); n != 0 {
+ t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ }
+ }
+ })
+ }
+}