summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-11-12 17:30:31 -0800
committergVisor bot <gvisor-bot@google.com>2020-11-12 17:33:21 -0800
commit1a972411b36b8ad2543d3ea614c92e60ccbdffab (patch)
tree43fd70a53b1ee47469ba3876920eaaee9863c813 /pkg/tcpip/network
parentae7ab0a330aaa1676d1fe066e3f5ac5fe805ec1c (diff)
Move packet handling to NetworkEndpoint
The NIC should not hold network-layer state or logic - network packet handling/forwarding should be performed at the network layer instead of the NIC. Fixes #4688 PiperOrigin-RevId: 342166985
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp_test.go4
-rw-r--r--pkg/tcpip/network/ip_test.go289
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go106
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go73
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go102
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go34
7 files changed, 376 insertions, 233 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index c118a2929..b38aff0b8 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -14,6 +14,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 087ee9c66..91d59d83f 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -439,6 +439,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
NetProto: protocol,
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index c7d26e14f..576792234 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -34,16 +35,16 @@ import (
)
const (
- localIPv4Addr = "\x0a\x00\x00\x01"
- remoteIPv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02")
+ ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00")
+ ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00")
+ ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03")
+ localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")
+ ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00")
+ ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
nicID = 1
)
@@ -299,6 +300,10 @@ func (t *testInterface) Enabled() bool {
return !t.mu.disabled
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) setEnabled(v bool) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -558,59 +563,131 @@ func TestIPv4Send(t *testing.T) {
}
}
-func TestIPv4Receive(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: true,
+func TestReceive(t *testing.T) {
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ v4 bool
+ epAddr tcpip.AddressWithPrefix
+ handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface)
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ v4: true,
+ epAddr: localIPv4Addr.WithPrefix(),
+ handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
+ const totalLen = header.IPv4MinimumSize + 30 /* payload length */
+
+ view := buffer.NewView(totalLen)
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: totalLen,
+ TTL: ipv4.DefaultTTL,
+ Protocol: 10,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if ok := parse.IPv4(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(pkt)
+ },
},
- }
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
- defer ep.Close()
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ v4: false,
+ epAddr: localIPv6Addr.WithPrefix(),
+ handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
+ const payloadLen = 30
+ view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: payloadLen,
+ NextHeader: 10,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
+ })
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
+ // Make payload be non-zero.
+ for i := header.IPv6MinimumSize; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
- totalLen := header.IPv4MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
+ // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen]
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if _, _, _, _, ok := parse.IPv6(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(pkt)
+ },
+ },
}
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+ })
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: test.v4,
+ },
+ }
+ ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject)
+ defer ep.Close()
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
- r.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ if ep, err := ep.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ } else {
+ ep.DecRef()
+ }
+
+ stat := s.Stats().IP.PacketsReceived
+ if got := stat.Value(); got != 0 {
+ t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ test.handlePacket(t, ep, &nic)
+ if nic.testObject.dataCalls != 1 {
+ t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ }
+ if got := stat.Value(); got != 1 {
+ t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ })
}
}
@@ -634,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
- if err != nil {
- t.Fatal(err)
- }
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
s := buildDummyStack(t)
@@ -705,8 +778,14 @@ func TestIPv4ReceiveControl(t *testing.T) {
nic.testObject.typ = c.expectedTyp
nic.testObject.extra = c.expectedExtra
+ addr := localIPv4Addr.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
+
pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
@@ -716,7 +795,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
}
func TestIPv4FragmentationReceive(t *testing.T) {
- s := buildDummyStack(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
testObject: testObject{
@@ -774,11 +855,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
nic.testObject.dstAddr = localIPv4Addr
nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
// Send first segment.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag1.ToVectorisedView(),
@@ -786,7 +862,14 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- r.PopulatePacketInfo(pkt)
+
+ addr := localIPv4Addr.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
+
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
@@ -799,7 +882,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
@@ -852,61 +934,6 @@ func TestIPv6Send(t *testing.T) {
}
}
-func TestIPv6Receive(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- totalLen := header.IPv6MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
- })
-
- // Make payload be non-zero.
- for i := header.IPv6MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv6Addr
- nic.testObject.dstAddr = localIPv6Addr
- nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
-
- r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
- r.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
- }
-}
-
func TestIPv6ReceiveControl(t *testing.T) {
newUint16 := func(v uint16) *uint16 { return &v }
@@ -933,13 +960,6 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
- r, err := buildIPv6Route(
- localIPv6Addr,
- "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
- )
- if err != nil {
- t.Fatal(err)
- }
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
s := buildDummyStack(t)
@@ -1018,8 +1038,13 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
+ addr := localIPv6Addr.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a376cb8ec..0af646df9 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -260,16 +260,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt)
-}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -286,24 +283,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -374,8 +374,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -400,9 +399,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -479,16 +479,66 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt)
+ return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ Data: stack.PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -497,6 +547,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
stats.IP.MalformedPacketsReceived.Increment()
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
+ addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ addressEndpoint.DecRef()
// There has been some confusion regarding verifying checksums. We need
// just look for negative 0 (0xffff) as the checksum, as it's not possible to
@@ -528,15 +593,16 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) {
+ if directedBroadcast := subnet.IsBroadcast(srcAddr); directedBroadcast || srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+ pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 76013daa1..001b9d66a 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -144,6 +144,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
NetProto: protocol,
@@ -174,13 +178,8 @@ func TestICMPCounts(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: test.useNeighborCache,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -206,11 +205,12 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addr := lladdr0.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -279,10 +279,9 @@ func TestICMPCounts(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -290,7 +289,7 @@ func TestICMPCounts(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -317,13 +316,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: true,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -349,11 +343,12 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addr := lladdr0.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -422,10 +417,9 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -433,7 +427,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -1775,17 +1769,15 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addr := lladdr0.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
-
- // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch.
- r.LocalAddress = test.destination
icmp := test.createPacket()
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.IPv6MinimumSize,
Data: buffer.View(icmp).ToVectorisedView(),
@@ -1795,10 +1787,9 @@ func TestCallsToNeighborCache(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.RemoteAddress,
- DstAddr: r.LocalAddress,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 0526190cc..38a0633bd 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -441,17 +441,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt, params.Protocol)
-}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -467,24 +463,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -558,8 +557,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -584,9 +582,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -640,16 +639,66 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt, proto)
+ return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv6(pkt.NetworkHeader().View())
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ Data: stack.PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -669,6 +718,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
+ addressEndpoint.DecRef()
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
@@ -681,8 +742,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 981d1371a..be83e9eb4 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
- }
-
{
subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
if err != nil {
@@ -73,6 +69,13 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
}
t.Cleanup(ep.Close)
+ addr := llladdr.WithPrefix()
+ if addressEP, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ addressEP.DecRef()
+ }
+
return s, ep
}
@@ -961,22 +964,17 @@ func TestNDPValidation(t *testing.T) {
for _, stackTyp := range stacks {
t.Run(stackTyp.name, func(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
// Create a stack with the assigned link-local address lladdr0
// and an endpoint to lladdr1.
s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
+ return s, ep
}
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
nextHdr := uint8(header.ICMPv6ProtocolNumber)
var extensions buffer.View
if atomicFragment {
@@ -994,13 +992,12 @@ func TestNDPValidation(t *testing.T) {
PayloadLength: uint16(len(payload) + len(extensions)),
NextHeader: nextHdr,
HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
}
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -1114,8 +1111,7 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ s, ep := setup(t)
if isRouter {
// Enabling forwarding makes the stack act as a router.
@@ -1131,7 +1127,7 @@ func TestNDPValidation(t *testing.T) {
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -1152,7 +1148,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {