diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2020-11-12 17:30:31 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-11-12 17:33:21 -0800 |
commit | 1a972411b36b8ad2543d3ea614c92e60ccbdffab (patch) | |
tree | 43fd70a53b1ee47469ba3876920eaaee9863c813 /pkg/tcpip/network/ipv6 | |
parent | ae7ab0a330aaa1676d1fe066e3f5ac5fe805ec1c (diff) |
Move packet handling to NetworkEndpoint
The NIC should not hold network-layer state or logic - network packet
handling/forwarding should be performed at the network layer instead
of the NIC.
Fixes #4688
PiperOrigin-RevId: 342166985
Diffstat (limited to 'pkg/tcpip/network/ipv6')
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 73 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 102 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 34 |
3 files changed, 128 insertions, 81 deletions
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 76013daa1..001b9d66a 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -144,6 +144,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ NetProto: protocol, @@ -174,13 +178,8 @@ func TestICMPCounts(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -206,11 +205,12 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addr := lladdr0.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -279,10 +279,9 @@ func TestICMPCounts(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -290,7 +289,7 @@ func TestICMPCounts(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -317,13 +316,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -349,11 +343,12 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addr := lladdr0.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -422,10 +417,9 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -433,7 +427,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -1775,17 +1769,15 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addr := lladdr0.WithPrefix() + if ep, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() - - // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. - r.LocalAddress = test.destination icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.IPv6MinimumSize, Data: buffer.View(icmp).ToVectorisedView(), @@ -1795,10 +1787,9 @@ func TestCallsToNeighborCache(t *testing.T) { PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, - SrcAddr: r.RemoteAddress, - DstAddr: r.LocalAddress, + SrcAddr: test.source, + DstAddr: test.destination, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0526190cc..38a0633bd 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -441,17 +441,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt, params.Protocol) -} -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -467,24 +463,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -558,8 +557,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -584,9 +582,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -640,16 +639,66 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt, proto) + return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv6(pkt.NetworkHeader().View()) + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + Data: stack.PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -669,6 +718,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } + addressEndpoint.DecRef() + // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. @@ -681,8 +742,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 981d1371a..be83e9eb4 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } - { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { @@ -73,6 +69,13 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig } t.Cleanup(ep.Close) + addr := llladdr.WithPrefix() + if addressEP, err := ep.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("ep.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + addressEP.DecRef() + } + return s, ep } @@ -961,22 +964,17 @@ func TestNDPValidation(t *testing.T) { for _, stackTyp := range stacks { t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() // Create a stack with the assigned link-local address lladdr0 // and an endpoint to lladdr1. s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r + return s, ep } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { nextHdr := uint8(header.ICMPv6ProtocolNumber) var extensions buffer.View if atomicFragment { @@ -994,13 +992,12 @@ func TestNDPValidation(t *testing.T) { PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -1114,8 +1111,7 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + s, ep := setup(t) if isRouter { // Enabling forwarding makes the stack act as a router. @@ -1131,7 +1127,7 @@ func TestNDPValidation(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -1152,7 +1148,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { |