summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD7
-rw-r--r--pkg/tcpip/network/arp/arp.go65
-rw-r--r--pkg/tcpip/network/arp/arp_test.go29
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD18
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go16
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go10
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go10
-rw-r--r--pkg/tcpip/network/ip_test.go79
-rw-r--r--pkg/tcpip/network/ipv4/BUILD7
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go40
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go278
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go195
-rw-r--r--pkg/tcpip/network/ipv6/BUILD9
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go254
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go617
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go147
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go270
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go219
19 files changed, 1871 insertions, 401 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index f36f49453..9d16ff8c9 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index d95d44f56..e7617229b 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -6,9 +7,7 @@ go_library(
name = "arp",
srcs = ["arp.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index fd6395fc1..da8482509 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -16,9 +16,9 @@
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
//
-// To use it in the networking stack, pass arp.ProtocolName as one of the
-// network protocols when calling stack.New. Then add an "arp" address to
-// every NIC on the stack that should respond to ARP requests. That is:
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
//
// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
// // handle err
@@ -33,9 +33,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ARP protocol name.
- ProtocolName = "arp"
-
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
@@ -45,7 +42,7 @@ const (
// endpoint implements stack.NetworkEndpoint.
type endpoint struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
}
@@ -61,7 +58,7 @@ func (e *endpoint) MTU() uint32 {
}
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
+ return e.nicID
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
@@ -82,16 +79,21 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- v := vv.First()
+func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+ v := pkt.Data.First()
h := header.ARP(v)
if !h.IsValid() {
return
@@ -100,19 +102,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
switch h.Op() {
case header.ARPRequest:
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
- pkt := header.ARP(hdr.Prepend(header.ARPSize))
- pkt.SetIPv4OverEthernet()
- pkt.SetOp(header.ARPReply)
- copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
- copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
- copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ packet := header.ARP(hdr.Prepend(header.ARPSize))
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ })
+ fallthrough // also fill the cache from requests
case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
}
}
@@ -129,12 +137,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
if addrWithPrefix.Address != ProtocolAddress {
return nil, tcpip.ErrBadLocalAddress
}
return &endpoint{
- nicid: nicid,
+ nicID: nicID,
linkEP: sender,
linkAddrCache: linkAddrCache,
}, nil
@@ -159,7 +167,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ })
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -200,8 +210,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 4c4b54469..8e6048a21 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -44,14 +44,19 @@ type testContext struct {
}
func newTestContext(t *testing.T) *testContext {
- s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ })
const defaultMTU = 65536
- id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(256, defaultMTU, stackLinkAddr)
+ wep := stack.LinkEndpoint(ep)
+
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -73,7 +78,7 @@ func newTestContext(t *testing.T) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
@@ -97,19 +102,21 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView())
+ c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{
+ Data: v.ToVectorisedView(),
+ })
}
for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
- pkt := <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto)
+ pi := <-c.linkEP.C
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
- rep := header.ARP(pkt.Header)
+ rep := header.ARP(pi.Pkt.Header.View())
if !rep.IsValid() {
- t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
}
if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 118bfc763..acf1e022c 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,7 +1,8 @@
-package(licenses = ["notice"])
-
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
go_template_instance(
name = "reassembler_list",
@@ -24,9 +25,10 @@ go_library(
"reassembler_list.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
+ "//pkg/tcpip",
"//pkg/tcpip/buffer",
],
)
@@ -42,11 +44,3 @@ go_test(
embed = [":fragmentation"],
deps = ["//pkg/tcpip/buffer"],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "reassembler_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 1628a82be..6da5238ec 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -17,6 +17,7 @@
package fragmentation
import (
+ "fmt"
"log"
"sync"
"time"
@@ -82,7 +83,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
// Process processes an incoming fragment belonging to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
f.mu.Lock()
r, ok := f.reassemblers[id]
if ok && r.tooOld(f.timeout) {
@@ -97,8 +98,15 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
f.mu.Unlock()
- res, done, consumed := r.process(first, last, more, vv)
-
+ res, done, consumed, err := r.process(first, last, more, vv)
+ if err != nil {
+ // We probably got an invalid sequence of fragments. Just
+ // discard the reassembler and move on.
+ f.mu.Lock()
+ f.release(r)
+ f.mu.Unlock()
+ return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ }
f.mu.Lock()
f.size += consumed
if done {
@@ -114,7 +122,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
}
f.mu.Unlock()
- return res, done
+ return res, done, nil
}
func (f *Fragmentation) release(r *reassembler) {
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 799798544..72c0f53be 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -83,7 +83,10 @@ func TestFragmentationProcess(t *testing.T) {
t.Run(c.comment, func(t *testing.T) {
f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
for i, in := range c.in {
- vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if err != nil {
+ t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ }
if !reflect.DeepEqual(vv, c.out[i].vv) {
t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
}
@@ -114,7 +117,10 @@ func TestReassemblingTimeout(t *testing.T) {
time.Sleep(2 * timeout)
// Send another fragment that completes a packet.
// However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done := f.Process(0, 1, 1, false, vv(1, "1"))
+ _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
+ if err != nil {
+ t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
+ }
if done {
t.Errorf("Fragmentation does not respect the reassembling timeout.")
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 8037f734b..9e002e396 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -86,7 +86,7 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, false, consumed
+ return buffer.VectorisedView{}, false, consumed, nil
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
@@ -96,13 +96,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
}
// Check if all the holes have been deleted and we are ready to reassamble.
if r.deleted < len(r.holes) {
- return buffer.VectorisedView{}, false, consumed
+ return buffer.VectorisedView{}, false, consumed, nil
}
res, err := r.heap.reassemble()
if err != nil {
- panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+ return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
}
- return res, true, consumed
+ return res, true, consumed, nil
}
func (r *reassembler) tooOld(timeout time.Duration) bool {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4b3bd74fa..4144a7837 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
- t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
+ t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- t.checkValues(trans, vv, remote, local)
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+ t.checkValues(trans, pkt.Data, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
}
@@ -144,32 +144,47 @@ func (*testObject) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
if t.v4 {
- h := header.IPv4(hdr.View())
+ h := header.IPv4(pkt.Header.View())
prot = tcpip.TransportProtocolNumber(h.Protocol())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
} else {
- h := header.IPv6(hdr.View())
+ h := header.IPv6(pkt.Header.View())
prot = tcpip.TransportProtocolNumber(h.NextHeader())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
}
- t.checkValues(prot, payload, srcAddr, dstAddr)
+ t.checkValues(prot, pkt.Data, srcAddr, dstAddr)
return nil
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -182,7 +197,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -221,7 +239,10 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -261,7 +282,9 @@ func TestIPv4Receive(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, view.ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: view.ToVectorisedView(),
+ })
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -349,7 +372,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
o.extra = c.expectedExtra
vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, vv)
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: vv,
+ })
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
@@ -412,13 +437,17 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send first segment.
- ep.HandlePacket(&r, frag1.ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: frag1.ToVectorisedView(),
+ })
if o.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
}
// Send second segment.
- ep.HandlePacket(&r, frag2.ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: frag2.ToVectorisedView(),
+ })
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -451,7 +480,10 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -491,7 +523,9 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, view.ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: view.ToVectorisedView(),
+ })
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -501,6 +535,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
newUint16 := func(v uint16) *uint16 { return &v }
const mtu = 0xffff
+ const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
cases := []struct {
name string
expectedCount int
@@ -552,7 +587,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: 20,
- SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ SrcAddr: outerSrcAddr,
DstAddr: localIpv6Addr,
})
@@ -599,8 +634,12 @@ func TestIPv6ReceiveControl(t *testing.T) {
o.typ = c.expectedTyp
o.extra = c.expectedExtra
- vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, vv)
+ // Set ICMPv6 checksum.
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: view[:len(view)-c.trunc].ToVectorisedView(),
+ })
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index be84fa63d..aeddfcdd4 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -9,9 +10,7 @@ go_library(
"ipv4.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index a25756443..32bf39e43 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,7 @@
package ipv4
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -24,8 +25,8 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- h := header.IPv4(vv.First())
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+ h := header.IPv4(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that the IP
// header plus 8 bytes of the transport header be included. So it's
@@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
}
hlen := int(h.HeaderLength())
- if vv.Size() < hlen || h.FragmentOffset() != 0 {
+ if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 {
// We won't be able to handle this if it doesn't contain the
// full IPv4 header, or if it's a fragment not at offset 0
// (because it won't have the transport header).
@@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
}
// Skip the ip header, then deliver control message.
- vv.TrimFront(hlen)
+ pkt.Data.TrimFront(hlen)
p := h.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
- v := vv.First()
+ v := pkt.Data.First()
if len(v) < header.ICMPv4MinimumSize {
received.Invalid.Increment()
return
@@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
// checksum. We'll have to reset this before we hand the packet
// off.
h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
if gotChecksum != wantChecksum {
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
received.Invalid.Increment()
return
}
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{
+ Data: pkt.Data.Clone(nil),
+ NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
+ })
- vv := vv.Clone(nil)
+ vv := pkt.Data.Clone(nil)
vv.TrimFront(header.ICMPv4MinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize)
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
@@ -95,7 +99,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: vv,
+ TransportHeader: buffer.View(pkt),
+ }); err != nil {
sent.Dropped.Increment()
return
}
@@ -104,19 +112,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv4EchoReply:
received.EchoReply.Increment()
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
received.DstUnreachable.Increment()
- vv.TrimFront(header.ICMPv4MinimumSize)
+ pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
case header.ICMPv4PortUnreachable:
- e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
case header.ICMPv4FragmentationNeeded:
mtu := uint32(h.MTU())
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
}
case header.ICMPv4SrcQuench:
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b7a06f525..e645cf62c 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -14,9 +14,9 @@
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4
@@ -32,9 +32,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv4 protocol name.
- ProtocolName = "ipv4"
-
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -42,28 +39,33 @@ const (
// TotalLength field of the ipv4 header.
MaxTotalSize = 0xffff
+ // DefaultTTL is the default time-to-live value for this endpoint.
+ DefaultTTL = 64
+
// buckets is the number of identifier buckets.
buckets = 2048
)
type endpoint struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
id stack.NetworkEndpointID
prefixLen int
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
+ protocol *protocol
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
e := &endpoint{
- nicid: nicid,
+ nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
}
return e, nil
@@ -71,7 +73,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
// DefaultTTL is the default time-to-live value for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
- return 255
+ return e.protocol.DefaultTTL()
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
@@ -87,7 +89,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
+ return e.nicID
}
// ID returns the ipv4 endpoint ID.
@@ -115,13 +117,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
}
// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is entirely in hdr but does not assume
-// that only the IP header is in hdr. It assumes that the input packet's stated
-// length matches the length of the hdr+payload. mtu includes the IP header and
-// options. This does not support the DontFragment IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error {
+// write. It assumes that the IP header is entirely in pkt.Header but does not
+// assume that only the IP header is in pkt.Header. It assumes that the input
+// packet's stated length matches the length of the header+payload. mtu
+// includes the IP header and options. This does not support the DontFragment
+// IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error {
// This packet is too big, it needs to be fragmented.
- ip := header.IPv4(hdr.View())
+ ip := header.IPv4(pkt.Header.View())
flags := ip.Flags()
// Update mtu to take into account the header, which will exist in all
@@ -135,122 +138,167 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff
outerMTU := innerMTU + int(ip.HeaderLength())
offset := ip.FragmentOffset()
- originalAvailableLength := hdr.AvailableLength()
+ originalAvailableLength := pkt.Header.AvailableLength()
for i := 0; i < n; i++ {
// Where possible, the first fragment that is sent has the same
- // hdr.UsedLength() as the input packet. The link-layer endpoint may depends
- // on this for looking at, eg, L4 headers.
+ // pkt.Header.UsedLength() as the input packet. The link-layer
+ // endpoint may depend on this for looking at, eg, L4 headers.
h := ip
if i > 0 {
- hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
- h = header.IPv4(hdr.Prepend(int(ip.HeaderLength())))
+ pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
+ h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength())))
copy(h, ip[:ip.HeaderLength()])
}
if i != n-1 {
h.SetTotalLength(uint16(outerMTU))
h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
} else {
- h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size()))
+ h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size()))
h.SetFlagsFragmentOffset(flags, offset)
}
h.SetChecksum(0)
h.SetChecksum(^h.CalculateChecksum())
offset += uint16(innerMTU)
if i > 0 {
- newPayload := payload.Clone([]buffer.View{})
+ newPayload := pkt.Data.Clone(nil)
newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
- payload.TrimFront(newPayload.Size())
+ pkt.Data.TrimFront(newPayload.Size())
continue
}
- // Special handling for the first fragment because it comes from the hdr.
- if outerMTU >= hdr.UsedLength() {
- // This fragment can fit all of hdr and possibly some of payload, too.
- newPayload := payload.Clone([]buffer.View{})
- newPayloadLength := outerMTU - hdr.UsedLength()
+ // Special handling for the first fragment because it comes
+ // from the header.
+ if outerMTU >= pkt.Header.UsedLength() {
+ // This fragment can fit all of pkt.Header and possibly
+ // some of pkt.Data, too.
+ newPayload := pkt.Data.Clone(nil)
+ newPayloadLength := outerMTU - pkt.Header.UsedLength()
newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
- payload.TrimFront(newPayloadLength)
+ pkt.Data.TrimFront(newPayloadLength)
} else {
- // The fragment is too small to fit all of hdr.
- startOfHdr := hdr
- startOfHdr.TrimBack(hdr.UsedLength() - outerMTU)
+ // The fragment is too small to fit all of pkt.Header.
+ startOfHdr := pkt.Header
+ startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ Header: startOfHdr,
+ Data: emptyVV,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
- // Add the unused bytes of hdr into the payload that remains to be sent.
- restOfHdr := hdr.View()[outerMTU:]
+ // Add the unused bytes of pkt.Header into the pkt.Data
+ // that remains to be sent.
+ restOfHdr := pkt.Header.View()[outerMTU:]
tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
- tmp.Append(payload)
- payload = tmp
+ tmp.Append(pkt.Data)
+ pkt.Data = tmp
}
}
return nil
}
-// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- length := uint16(hdr.UsedLength() + payload.Size())
+ length := uint16(hdr.UsedLength() + payloadSize)
id := uint32(0)
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
}
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: length,
ID: uint16(id),
- TTL: ttl,
- Protocol: uint8(protocol),
+ TTL: params.TTL,
+ TOS: params.TOS,
+ Protocol: uint8(params.Protocol),
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
+ return ip
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ // The inbound path expects the network header to still be in
+ // the PacketBuffer's Data field.
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, vv)
+
+ e.HandlePacket(&loopedR, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
+
loopedR.Release()
}
if loop&stack.PacketOut == 0 {
return nil
}
- if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU()))
+ if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
}
- if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
return nil
}
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ if loop&stack.PacketLoop != 0 {
+ panic("multiple packets in local loop")
+ }
+ if loop&stack.PacketOut == 0 {
+ return len(pkts), nil
+ }
+
+ for i := range pkts {
+ ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params)
+ pkts[i].NetworkHeader = buffer.View(ip)
+ }
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+}
+
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
- ip := header.IPv4(payload.First())
- if !ip.IsValid(payload.Size()) {
+ ip := header.IPv4(pkt.Data.First())
+ if !ip.IsValid(pkt.Data.Size()) {
return tcpip.ErrInvalidOptionValue
}
// Always set the total length.
- ip.SetTotalLength(uint16(payload.Size()))
+ ip.SetTotalLength(uint16(pkt.Data.Size()))
// Set the source address when zero.
if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
@@ -264,10 +312,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
// Set the packet ID when zero.
if ip.ID() == 0 {
id := uint32(0)
- if payload.Size() > header.IPv4MaximumHeaderSize+8 {
+ if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
}
ip.SetID(uint16(id))
}
@@ -277,37 +325,63 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
ip.SetChecksum(^ip.CalculateChecksum())
if loop&stack.PacketLoop != 0 {
- e.HandlePacket(r, payload)
+ e.HandlePacket(r, pkt.Clone())
}
if loop&stack.PacketOut == 0 {
return nil
}
- hdr := buffer.NewPrependableFromView(payload.ToView())
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+
+ ip = ip[:ip.HeaderLength()]
+ pkt.Header = buffer.NewPrependableFromView(buffer.View(ip))
+ pkt.Data.TrimFront(int(ip.HeaderLength()))
+ return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- headerView := vv.First()
+func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+ headerView := pkt.Data.First()
h := header.IPv4(headerView)
- if !h.IsValid(vv.Size()) {
+ if !h.IsValid(pkt.Data.Size()) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ pkt.NetworkHeader = headerView[:h.HeaderLength()]
hlen := int(h.HeaderLength())
tlen := int(h.TotalLength())
- vv.TrimFront(hlen)
- vv.CapLength(tlen - hlen)
+ pkt.Data.TrimFront(hlen)
+ pkt.Data.CapLength(tlen - hlen)
more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
if more || h.FragmentOffset() != 0 {
+ if pkt.Data.Size() == 0 {
+ // Drop the packet as it's marked as a fragment but has
+ // no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
// The packet is a fragment, let's try to reassemble it.
- last := h.FragmentOffset() + uint16(vv.Size()) - 1
+ last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
+ // Drop the packet if the fragmentOffset is incorrect. i.e the
+ // combination of fragmentOffset and pkt.Data.size() causes a
+ // wrap around resulting in last being less than the offset.
+ if last < h.FragmentOffset() {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
var ready bool
- vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ var err error
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
if !ready {
return
}
@@ -315,24 +389,24 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
headerView.CapLength(hlen)
- e.handleICMP(r, headerView, vv)
+ e.handleICMP(r, pkt)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
}
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {}
-type protocol struct{}
+type protocol struct {
+ ids []uint32
+ hashIV uint32
-// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
}
// Number returns the ipv4 protocol number.
@@ -358,12 +432,34 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
// SetOption implements NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// Option implements NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -378,7 +474,7 @@ func calculateMTU(mtu uint32) uint32 {
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
t := r.LocalAddress
a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = r.RemoteAddress
@@ -386,22 +482,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-var (
- ids []uint32
- hashIV uint32
-)
-
-func init() {
- ids = make([]uint32, buckets)
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
r := hash.RandN32(1 + buckets)
for i := range ids {
ids[i] = r[i]
}
- hashIV = r[buckets]
+ hashIV := r[buckets]
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+ return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 1b5a55bea..e900f1b45 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -33,24 +33,20 @@ import (
)
func TestExcludeBroadcast(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
const defaultMTU = 65536
- id, _ := channel.New(256, defaultMTU, "")
+ ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ ep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
- if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
NIC: 1,
@@ -117,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) {
+func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
source = append(source, sourcePacketInfo.Header.View()...)
- source = append(source, sourcePacketInfo.Payload.ToView()...)
+ source = append(source, sourcePacketInfo.Data.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -136,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
for i, packet := range packets {
// Confirm that the packet is valid.
allBytes := packet.Header.View().ToVectorisedView()
- allBytes.Append(packet.Payload)
+ allBytes.Append(packet.Data)
ip := header.IPv4(allBytes.ToView())
if !ip.IsValid(len(ip)) {
t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
@@ -177,28 +173,19 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
type errorChannel struct {
*channel.Endpoint
- Ch chan packetInfo
+ Ch chan tcpip.PacketBuffer
packetCollectorErrors []*tcpip.Error
}
// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
// will return successive errors from packetCollectorErrors until the list is
// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) {
- _, e := channel.New(size, mtu, linkAddr)
- ec := errorChannel{
- Endpoint: e,
- Ch: make(chan packetInfo, size),
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
+ return &errorChannel{
+ Endpoint: channel.New(size, mtu, linkAddr),
+ Ch: make(chan tcpip.PacketBuffer, size),
packetCollectorErrors: packetCollectorErrors,
}
-
- return stack.RegisterLinkEndpoint(e), &ec
-}
-
-// packetInfo holds all the information about an outbound packet.
-type packetInfo struct {
- Header buffer.Prependable
- Payload buffer.VectorisedView
}
// Drain removes all outbound packets from the channel and counts them.
@@ -215,14 +202,9 @@ func (e *errorChannel) Drain() int {
}
// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- p := packetInfo{
- Header: hdr,
- Payload: payload,
- }
-
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
select {
- case e.Ch <- p:
+ case e.Ch <- pkt:
default:
}
@@ -241,10 +223,11 @@ type context struct {
func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
// Make the packet and write it.
- s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
- _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- linkEPId := stack.RegisterLinkEndpoint(linkEP)
- s.CreateNIC(1, linkEPId)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ s.CreateNIC(1, ep)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -266,7 +249,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
}
return context{
Route: r,
- linkEP: linkEP,
+ linkEP: ep,
}
}
@@ -298,18 +281,21 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := packetInfo{
+ source := tcpip.PacketBuffer{
Header: hdr,
// Save the source payload because WritePacket will modify it.
- Payload: payload.Clone([]buffer.View{}),
+ Data: payload.Clone(nil),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42)
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
if err != nil {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []packetInfo
+ var results []tcpip.PacketBuffer
L:
for {
select {
@@ -351,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42)
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
@@ -368,3 +357,119 @@ func TestFragmentationErrors(t *testing.T) {
})
}
}
+
+func TestInvalidFragments(t *testing.T) {
+ // These packets have both IHL and TotalLength set to 0.
+ testCases := []struct {
+ name string
+ packets [][]byte
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ }{
+ {
+ "ihl_totallen_zero_valid_frag_offset",
+ [][]byte{
+ {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 0,
+ },
+ {
+ "ihl_totallen_zero_invalid_frag_offset",
+ [][]byte{
+ {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 0,
+ },
+ {
+ // Total Length of 37(20 bytes IP header + 17 bytes of
+ // payload)
+ // Frag Offset of 0x1ffe = 8190*8 = 65520
+ // Leading to the fragment end to be past 65535.
+ "ihl_totallen_valid_invalid_frag_offset_1",
+ [][]byte{
+ {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 1,
+ },
+ // The following 3 tests were found by running a fuzzer and were
+ // triggering a panic in the IPv4 reassembler code.
+ {
+ "ihl_less_than_ipv4_minimum_size_1",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "ihl_less_than_ipv4_minimum_size_2",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "ihl_less_than_ipv4_minimum_size_3",
+ [][]byte{
+ {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 2,
+ 0,
+ },
+ {
+ "fragment_with_short_total_len_extra_payload",
+ [][]byte{
+ {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ },
+ 1,
+ 1,
+ },
+ {
+ "multiple_fragments_with_more_fragments_set_to_false",
+ [][]byte{
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ },
+ 1,
+ 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ const nicID tcpip.NICID = 42
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ },
+ })
+
+ var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
+ var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
+ ep := channel.New(10, 1500, linkAddr)
+ s.CreateNIC(nicID, sniffer.New(ep))
+
+ for _, pkt := range tc.packets {
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ })
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index c71b69123..e4e273460 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -9,9 +10,7 @@ go_library(
"ipv6.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -25,6 +24,7 @@ go_test(
size = "small",
srcs = [
"icmp_test.go",
+ "ipv6_test.go",
"ndp_test.go",
],
embed = [":ipv6"],
@@ -36,6 +36,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index b4d0295bf..1c3410618 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -21,21 +21,12 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-const (
- // ndpHopLimit is the expected IP hop limit value of 255 for received
- // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
- // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
- // drop the NDP packet. All outgoing NDP packets must use this value for
- // its IP hop limit field.
- ndpHopLimit = 255
-)
-
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- h := header.IPv6(vv.First())
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+ h := header.IPv6(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that up to
// 1280 bytes of the original packet be included. So it's likely that it
@@ -49,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
// Skip the IP header, then handle the fragmentation header if there
// is one.
- vv.TrimFront(header.IPv6MinimumSize)
+ pkt.Data.TrimFront(header.IPv6MinimumSize)
p := h.TransportProtocol()
if p == header.IPv6FragmentHeader {
- f := header.IPv6Fragment(vv.First())
+ f := header.IPv6Fragment(pkt.Data.First())
if !f.IsValid() || f.FragmentOffset() != 0 {
// We can't handle fragments that aren't at offset 0
// because they don't have the transport headers.
@@ -61,35 +52,54 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
// Skip fragmentation header and find out the actual protocol
// number.
- vv.TrimFront(header.IPv6FragmentHeaderSize)
+ pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
p = f.TransportProtocol()
}
// Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
- v := vv.First()
+ v := pkt.Data.First()
if len(v) < header.ICMPv6MinimumSize {
received.Invalid.Increment()
return
}
h := header.ICMPv6(v)
+ iph := header.IPv6(netHeader)
+
+ // Validate ICMPv6 checksum before processing the packet.
+ //
+ // Only the first view in vv is accounted for by h. To account for the
+ // rest of vv, a shallow copy is made and the first view is removed.
+ // This copy is used as extra payload during the checksum calculation.
+ payload := pkt.Data
+ payload.RemoveFirst()
+ if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
+ received.Invalid.Increment()
+ return
+ }
// As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
// 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
- // in the IPv6 header is not set to 255.
+ // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
+ // set to 0.
switch h.Type() {
case header.ICMPv6NeighborSolicit,
header.ICMPv6NeighborAdvert,
header.ICMPv6RouterSolicit,
header.ICMPv6RouterAdvert,
header.ICMPv6RedirectMsg:
- if header.IPv6(netHeader).HopLimit() != ndpHopLimit {
+ if iph.HopLimit() != header.NDPHopLimit {
+ received.Invalid.Increment()
+ return
+ }
+
+ if h.Code() != 0 {
received.Invalid.Increment()
return
}
@@ -103,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
mtu := h.MTU()
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
@@ -113,33 +123,80 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch h.Code() {
case header.ICMPv6PortUnreachable:
- e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
- if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
+
+ ns := header.NDPNeighborSolicit(h.NDPPayload())
+ targetAddr := ns.TargetAddress()
+ s := r.Stack()
+ rxNICID := r.NICID()
+
+ isTentative, err := s.IsAddrTentative(rxNICID, targetAddr)
+ if err != nil {
+ // We will only get an error if rxNICID is unrecognized,
+ // which should not happen. For now short-circuit this
+ // packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ }
+
+ if isTentative {
+ // If the target address is tentative and the source
+ // of the packet is a unicast (specified) address, then
+ // the source of the packet is attempting to perform
+ // address resolution on the target. In this case, the
+ // solicitation is silently ignored, as per RFC 4862
+ // section 5.4.3.
+ //
+ // If the target address is tentative and the source of
+ // the packet is the unspecified address (::), then we
+ // know another node is also performing DAD for the
+ // same address (since targetAddr is tentative for us,
+ // we know we are also performing DAD on it). In this
+ // case we let the stack know so it can handle such a
+ // scenario and do nothing further with the NDP NS.
+ if iph.SourceAddress() == header.IPv6Any {
+ s.DupTentativeAddrDetected(rxNICID, targetAddr)
+ }
+
+ // Do not handle neighbor solicitations targeted
+ // to an address that is tentative on the received
+ // NIC any further.
+ return
+ }
+
+ // At this point we know that targetAddr is not tentative on
+ // rxNICID so the packet is processed as defined in RFC 4861,
+ // as per RFC 4862 section 5.4.3.
+
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag
- copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr)
- pkt[icmpV6OptOffset] = ndpOptDstLinkAddr
- pkt[icmpV6LengthOffset] = 1
- copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
+ }
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ packet.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(targetAddr)
+ opts := na.Options()
+ opts.Serialize(optsSerializer)
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
@@ -152,9 +209,26 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
+ // TODO(tamird/ghanan): there exists an explicit NDP option that is
+ // used to update the neighbor table with link addresses for a
+ // neighbor from an NS (see the Source Link Layer option RFC
+ // 4861 section 4.6.1 and section 7.2.3).
+ //
+ // Furthermore, the entirety of NDP handling here seems to be
+ // contradicted by RFC 4861.
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
+
+ // RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
+ //
+ // 7.1.2. Validation of Neighbor Advertisements
+ //
+ // The IP Hop Limit field has a value of 255, i.e., the packet
+ // could not possibly have been forwarded by a router.
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ }); err != nil {
sent.Dropped.Increment()
return
}
@@ -166,10 +240,45 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
- e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
+
+ na := header.NDPNeighborAdvert(h.NDPPayload())
+ targetAddr := na.TargetAddress()
+ stack := r.Stack()
+ rxNICID := r.NICID()
+
+ isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr)
+ if err != nil {
+ // We will only get an error if rxNICID is unrecognized,
+ // which should not happen. For now short-circuit this
+ // packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ }
+
+ if isTentative {
+ // We just got an NA from a node that owns an address we
+ // are performing DAD on, implying the address is not
+ // unique. In this case we let the stack know so it can
+ // handle such a scenario and do nothing furthur with
+ // the NDP NA.
+ stack.DupTentativeAddrDetected(rxNICID, targetAddr)
+ return
+ }
+
+ // At this point we know that the targetAddress is not tentative
+ // on rxNICID. However, targetAddr may still be assigned to
+ // rxNICID but not tentative (it could be permanent). Such a
+ // scenario is beyond the scope of RFC 4862. As such, we simply
+ // ignore such a scenario for now and proceed as normal.
+ //
+ // TODO(b/143147598): Handle the scenario described above. Also
+ // inform the netstack integration that a duplicate address was
+ // detected outside of DAD.
+
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
}
case header.ICMPv6EchoRequest:
@@ -178,14 +287,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
-
- vv.TrimFront(header.ICMPv6EchoMinimumSize)
+ pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
- copy(pkt, h)
- pkt.SetType(header.ICMPv6EchoReply)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
- if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(packet, h)
+ packet.SetType(header.ICMPv6EchoReply)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: pkt.Data,
+ }); err != nil {
sent.Dropped.Increment()
return
}
@@ -197,7 +308,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt)
case header.ICMPv6TimeExceeded:
received.TimeExceeded.Increment()
@@ -209,8 +320,51 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.RouterSolicit.Increment()
case header.ICMPv6RouterAdvert:
+ routerAddr := iph.SourceAddress()
+
+ //
+ // Validate the RA as per RFC 4861 section 6.1.2.
+ //
+
+ // Is the IP Source Address a link-local address?
+ if !header.IsV6LinkLocalAddress(routerAddr) {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ p := h.NDPPayload()
+
+ // Is the NDP payload of sufficient size to hold a Router
+ // Advertisement?
+ if len(p) < header.NDPRAMinimumSize {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ ra := header.NDPRouterAdvert(p)
+ opts := ra.Options()
+
+ // Are options valid as per the wire format?
+ if _, err := opts.Iter(true); err != nil {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ //
+ // At this point, we have a valid Router Advertisement, as far
+ // as RFC 4861 section 6.1.2 is concerned.
+ //
+
received.RouterAdvert.Increment()
+ // Tell the NIC to handle the RA.
+ stack := r.Stack()
+ rxNICID := r.NICID()
+ stack.HandleNDPRA(rxNICID, routerAddr, ra)
+
case header.ICMPv6RedirectMsg:
received.RedirectMsg.Increment()
@@ -262,13 +416,15 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: defaultIPv6HopLimit,
+ HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ })
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 227a65cf2..335f634d5 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,7 +15,6 @@
package ipv6
import (
- "fmt"
"reflect"
"strings"
"testing"
@@ -31,7 +30,7 @@ import (
)
const (
- linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
+ linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)
@@ -56,7 +55,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error {
return nil
}
@@ -66,7 +65,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) {
}
type stubLinkAddressCache struct {
@@ -81,10 +80,12 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li
}
func TestICMPCounts(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
{
- id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
@@ -130,7 +131,7 @@ func TestICMPCounts(t *testing.T) {
{header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize},
{header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize},
{header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize},
- {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize},
+ {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize},
{header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize},
{header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize},
{header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize},
@@ -142,11 +143,13 @@ func TestICMPCounts(t *testing.T) {
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: r.DefaultTTL(),
+ HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, hdr.View().ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
}
for _, typ := range types {
@@ -177,13 +180,10 @@ func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
t := v.Type()
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
- switch v.Kind() {
- case reflect.Ptr:
- f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter))
- case reflect.Struct:
+ if s, ok := v.Interface().(*tcpip.StatCounter); ok {
+ f(t.Field(i).Name, s)
+ } else {
visitStats(v, f)
- default:
- panic(fmt.Sprintf("unexpected type %s", v.Type()))
}
}
}
@@ -206,41 +206,38 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
- s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
- s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
+ s0: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ s1: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
}
const defaultMTU = 65536
- _, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
- c.linkEP0 = linkEP0
- wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0}
- id0 := stack.RegisterLinkEndpoint(wrappedEP0)
+ c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+
+ wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
- id0 = sniffer.New(id0)
+ wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, id0); err != nil {
+ if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil {
- t.Fatalf("AddAddress sn lladdr0: %v", err)
- }
- _, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
- c.linkEP1 = linkEP1
- wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1}
- id1 := stack.RegisterLinkEndpoint(wrappedEP1)
- if err := c.s1.CreateNIC(1, id1); err != nil {
+ c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
+ if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil {
- t.Fatalf("AddAddress sn lladdr1: %v", err)
- }
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
if err != nil {
@@ -279,20 +276,22 @@ type routeArgs struct {
func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pkt := <-args.src.C
+ pi := <-args.src.C
{
- views := []buffer.View{pkt.Header, pkt.Payload}
- size := len(pkt.Header) + len(pkt.Payload)
+ views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
+ size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv)
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{
+ Data: vv,
+ })
}
- if pkt.Proto != ProtocolNumber {
- t.Errorf("unexpected protocol number %d", pkt.Proto)
+ if pi.Proto != ProtocolNumber {
+ t.Errorf("unexpected protocol number %d", pi.Proto)
return
}
- ipv6 := header.IPv6(pkt.Header)
+ ipv6 := header.IPv6(pi.Pkt.Header.View())
transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
if transProto != header.ICMPv6ProtocolNumber {
t.Errorf("unexpected transport protocol number %d", transProto)
@@ -364,3 +363,537 @@ func TestLinkResolution(t *testing.T) {
routeICMPv6Packet(t, args, nil)
}
}
+
+func TestICMPChecksumValidationSimple(t *testing.T) {
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ },
+ {
+ "RouterSolicit",
+ header.ICMPv6RouterSolicit,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ "RouterAdvert",
+ header.ICMPv6RouterAdvert,
+ header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ "NeighborSolicit",
+ header.ICMPv6NeighborSolicit,
+ header.ICMPv6NeighborSolicitMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ "NeighborAdvert",
+ header.ICMPv6NeighborAdvert,
+ header.ICMPv6NeighborAdvertSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ "RedirectMsg",
+ header.ICMPv6RedirectMsg,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
+ pkt := header.ICMPv6(hdr.Prepend(size))
+ pkt.SetType(typ)
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(size),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
+
+func TestICMPChecksumValidationWithPayload(t *testing.T) {
+ const simpleBodySize = 64
+ simpleBody := func(view buffer.View) {
+ for i := 0; i < simpleBodySize; i++ {
+ view[i] = uint8(i)
+ }
+ }
+
+ const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
+ errorICMPBody := func(view buffer.View) {
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: simpleBodySize,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
+ })
+ simpleBody(view[header.IPv6MinimumSize:])
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ payloadSize int
+ payload func(buffer.View)
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
+ icmpSize := size + payloadSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(typ)
+ payloadFn(pkt.Payload())
+
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ }
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
+
+func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
+ const simpleBodySize = 64
+ simpleBody := func(view buffer.View) {
+ for i := 0; i < simpleBodySize; i++ {
+ view[i] = uint8(i)
+ }
+ }
+
+ const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
+ errorICMPBody := func(view buffer.View) {
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: simpleBodySize,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
+ })
+ simpleBody(view[header.IPv6MinimumSize:])
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ payloadSize int
+ payload func(buffer.View)
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
+ pkt := header.ICMPv6(hdr.Prepend(size))
+ pkt.SetType(typ)
+
+ payload := buffer.NewView(payloadSize)
+ payloadFn(payload)
+
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
+ }
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(size + payloadSize),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 331a8bdaa..dd31f0fb7 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -14,13 +14,15 @@
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv6
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -28,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv6 protocol name.
- ProtocolName = "ipv6"
-
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -38,23 +37,24 @@ const (
// PayloadLength field of the ipv6 header.
maxPayloadSize = 0xffff
- // defaultIPv6HopLimit is the default hop limit for IPv6 Packets
- // egressed by Netstack.
- defaultIPv6HopLimit = 255
+ // DefaultTTL is the default hop limit for IPv6 Packets egressed by
+ // Netstack.
+ DefaultTTL = 64
)
type endpoint struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
id stack.NetworkEndpointID
prefixLen int
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
+ protocol *protocol
}
// DefaultTTL is the default hop limit for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
- return 255
+ return e.protocol.DefaultTTL()
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
@@ -65,7 +65,7 @@ func (e *endpoint) MTU() uint32 {
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
+ return e.nicID
}
// ID returns the ipv6 endpoint ID.
@@ -97,25 +97,37 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
- length := uint16(hdr.UsedLength() + payload.Size())
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 {
+ length := uint16(hdr.UsedLength() + payloadSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
- NextHeader: uint8(protocol),
- HopLimit: ttl,
+ NextHeader: uint8(params.Protocol),
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
+ return ip
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ // The inbound path expects the network header to still be in
+ // the PacketBuffer's Data field.
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, vv)
+
+ e.HandlePacket(&loopedR, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
+
loopedR.Release()
}
if loop&stack.PacketOut == 0 {
@@ -123,49 +135,68 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
}
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
+ return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ if loop&stack.PacketLoop != 0 {
+ panic("not implemented")
+ }
+ if loop&stack.PacketOut == 0 {
+ return len(pkts), nil
+ }
+
+ for i := range pkts {
+ hdr := &pkts[i].Header
+ size := pkts[i].DataSize
+ ip := e.addIPHeader(r, hdr, size, params)
+ pkts[i].NetworkHeader = buffer.View(ip)
+ }
+
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
}
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
// TODO(b/119580726): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- headerView := vv.First()
+func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+ headerView := pkt.Data.First()
h := header.IPv6(headerView)
- if !h.IsValid(vv.Size()) {
+ if !h.IsValid(pkt.Data.Size()) {
return
}
- vv.TrimFront(header.IPv6MinimumSize)
- vv.CapLength(int(h.PayloadLength()))
+ pkt.NetworkHeader = headerView[:header.IPv6MinimumSize]
+ pkt.Data.TrimFront(header.IPv6MinimumSize)
+ pkt.Data.CapLength(int(h.PayloadLength()))
p := h.TransportProtocol()
if p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, vv)
+ e.handleICMP(r, headerView, pkt)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
}
// Close cleans up resources associated with the endpoint.
func (*endpoint) Close() {}
-type protocol struct{}
-
-// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+type protocol struct {
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
}
// Number returns the ipv6 protocol number.
@@ -190,25 +221,48 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &endpoint{
- nicid: nicid,
+ nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
+ protocol: p,
}, nil
}
// SetOption implements NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// Option implements NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -221,8 +275,7 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{defaultTTL: DefaultTTL}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
new file mode 100644
index 000000000..1cbfa7278
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -0,0 +1,270 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ // The least significant 3 bytes are the same as addr2 so both addr2 and
+ // addr3 will have the same solicited-node address.
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+)
+
+// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
+// expected Neighbor Advertisement received count after receiving the packet.
+func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ // Receive ICMP packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+
+ if got := stats.NeighborAdvert.Value(); got != want {
+ t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
+ }
+}
+
+// testReceiveUDP tests receiving a UDP packet from src to dst. want is the
+// expected UDP received count after receiving the packet.
+func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+
+ // Receive UDP Packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+
+ // UDP pseudo-header checksum.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
+
+ // UDP checksum
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stat := s.Stats().UDP.PacketsReceived
+
+ if got := stat.Value(); got != want {
+ t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
+ }
+}
+
+// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
+// UDP packets destined to the IPv6 link-local all-nodes multicast address.
+func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should receive a packet destined to the all-nodes
+ // multicast address.
+ test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
+ })
+ }
+}
+
+// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
+// packets destined to the IPv6 solicited-node address of an assigned IPv6
+// address.
+func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ snmc := header.SolicitedNodeAddr(addr2)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as we haven't added
+ // those addresses.
+ test.rxf(t, s, e, addr1, snmc, 0)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
+ }
+
+ // Should receive a packet destined to the solicited
+ // node address of addr2/addr3 now that we have added
+ // added addr2.
+ test.rxf(t, s, e, addr1, snmc, 1)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have added addr3.
+ test.rxf(t, s, e, addr1, snmc, 2)
+
+ if err := s.RemoveAddress(1, addr2); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have removed addr2.
+ test.rxf(t, s, e, addr1, snmc, 3)
+
+ if err := s.RemoveAddress(1, addr3); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as both of them got
+ // removed.
+ test.rxf(t, s, e, addr1, snmc, 3)
+ })
+ }
+}
+
+// TestAddIpv6Address tests adding IPv6 addresses.
+func TestAddIpv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ }{
+ // This test is in response to b/140943433.
+ {
+ "Nil",
+ tcpip.Address([]byte(nil)),
+ },
+ {
+ "ValidUnicast",
+ addr1,
+ },
+ {
+ "ValidLinkLocalUnicast",
+ lladdr0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
+ t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != test.addr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 8e4cf0e74..0dbce14a0 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
@@ -31,16 +32,18 @@ import (
func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
- {
- id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{})
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
- if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
- }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
+
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+ if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
}
+
{
subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
if err != nil {
@@ -95,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(r, hdr.View().ToVectorisedView())
+ ep.HandlePacket(r, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
}
types := []struct {
@@ -107,7 +112,7 @@ func TestHopLimitValidation(t *testing.T) {
{"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
}},
- {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterAdvert
}},
{"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
@@ -148,7 +153,7 @@ func TestHopLimitValidation(t *testing.T) {
// Receive the NDP packet with an invalid hop limit
// value.
- handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r)
+ handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r)
// Invalid count should have increased.
if got := invalid.Value(); got != 1 {
@@ -162,7 +167,7 @@ func TestHopLimitValidation(t *testing.T) {
}
// Receive the NDP packet with a valid hop limit value.
- handleIPv6Payload(hdr, ndpHopLimit, ep, &r)
+ handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r)
// Rx count of NDP packet of type typ.typ should have
// increased.
@@ -177,3 +182,191 @@ func TestHopLimitValidation(t *testing.T) {
})
}
}
+
+// TestRouterAdvertValidation tests that when the NIC is configured to handle
+// NDP Router Advertisement packets, it validates the Router Advertisement
+// properly before handling them.
+func TestRouterAdvertValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ src tcpip.Address
+ hopLimit uint8
+ code uint8
+ ndpPayload []byte
+ expectedSuccess bool
+ }{
+ {
+ "OK",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "NonLinkLocalSourceAddr",
+ addr1,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "HopLimitNot255",
+ lladdr0,
+ 254,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NonZeroCode",
+ lladdr0,
+ 255,
+ 1,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NDPPayloadTooSmall",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "OKWithOptions",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ 2, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "OptionWithZeroLength",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ // Invalid as it has 0 length.
+ 2, 0, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
+
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
+ }
+ })
+ }
+}