summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv6
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-11-12 17:30:31 -0800
committergVisor bot <gvisor-bot@google.com>2020-11-12 17:33:21 -0800
commit1a972411b36b8ad2543d3ea614c92e60ccbdffab (patch)
tree43fd70a53b1ee47469ba3876920eaaee9863c813 /pkg/tcpip/network/ipv6
parentae7ab0a330aaa1676d1fe066e3f5ac5fe805ec1c (diff)
Move packet handling to NetworkEndpoint
The NIC should not hold network-layer state or logic - network packet handling/forwarding should be performed at the network layer instead of the NIC. Fixes #4688 PiperOrigin-RevId: 342166985
Diffstat (limited to 'pkg/tcpip/network/ipv6')
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go73
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go102
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go34
3 files changed, 128 insertions, 81 deletions
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 76013daa1..001b9d66a 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -144,6 +144,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
r := stack.Route{
NetProto: protocol,
@@ -174,13 +178,8 @@ func TestICMPCounts(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: test.useNeighborCache,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -206,11 +205,12 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addr := lladdr0.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -279,10 +279,9 @@ func TestICMPCounts(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -290,7 +289,7 @@ func TestICMPCounts(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -317,13 +316,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: true,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -349,11 +343,12 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addr := lladdr0.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -422,10 +417,9 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -433,7 +427,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -1775,17 +1769,15 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addr := lladdr0.WithPrefix()
+ if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
-
- // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch.
- r.LocalAddress = test.destination
icmp := test.createPacket()
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.IPv6MinimumSize,
Data: buffer.View(icmp).ToVectorisedView(),
@@ -1795,10 +1787,9 @@ func TestCallsToNeighborCache(t *testing.T) {
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
- SrcAddr: r.RemoteAddress,
- DstAddr: r.LocalAddress,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 0526190cc..38a0633bd 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -441,17 +441,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt, params.Protocol)
-}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -467,24 +463,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -558,8 +557,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -584,9 +582,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -640,16 +639,66 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt, proto)
+ return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv6(pkt.NetworkHeader().View())
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ Data: stack.PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -669,6 +718,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
+ addressEndpoint.DecRef()
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
@@ -681,8 +742,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 981d1371a..be83e9eb4 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
- }
-
{
subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
if err != nil {
@@ -73,6 +69,13 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
}
t.Cleanup(ep.Close)
+ addr := llladdr.WithPrefix()
+ if addressEP, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ addressEP.DecRef()
+ }
+
return s, ep
}
@@ -961,22 +964,17 @@ func TestNDPValidation(t *testing.T) {
for _, stackTyp := range stacks {
t.Run(stackTyp.name, func(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
// Create a stack with the assigned link-local address lladdr0
// and an endpoint to lladdr1.
s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
+ return s, ep
}
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
nextHdr := uint8(header.ICMPv6ProtocolNumber)
var extensions buffer.View
if atomicFragment {
@@ -994,13 +992,12 @@ func TestNDPValidation(t *testing.T) {
PayloadLength: uint16(len(payload) + len(extensions)),
NextHeader: nextHdr,
HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
}
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -1114,8 +1111,7 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ s, ep := setup(t)
if isRouter {
// Enabling forwarding makes the stack act as a router.
@@ -1131,7 +1127,7 @@ func TestNDPValidation(t *testing.T) {
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -1152,7 +1148,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {