diff options
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 73 |
1 files changed, 70 insertions, 3 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 38d066cd1..dbd304b7e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -40,6 +40,9 @@ type NIC struct { endpoints map[NetworkEndpointID]*referencedNetworkEndpoint addressRanges []tcpip.Subnet mcastJoins map[NetworkEndpointID]int32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint stats NICStats @@ -78,7 +81,7 @@ const ( ) func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { - return &NIC{ + nic := &NIC{ stack: stack, id: id, name: name, @@ -87,6 +90,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), + packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -101,6 +105,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback dad: make(map[tcpip.Address]dadState), }, } + + // Register supported packet endpoint protocols. + for _, netProto := range header.Ethertypes { + nic.packetEPs[netProto] = []PacketEndpoint{} + } + for _, netProto := range stack.networkProtocols { + nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + } + + return nic } // enable enables the NIC. enable will attach the link to its LinkEndpoint and @@ -631,7 +645,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size())) @@ -641,6 +655,26 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr return } + // If no local link layer address is provided, assume it was sent + // directly to this NIC. + if local == "" { + local = n.linkEP.LinkAddress() + } + + // Are any packet sockets listening for this network protocol? + n.mu.RLock() + packetEPs := n.packetEPs[protocol] + // Check whether there are packet sockets listening for every protocol. + // If we received a packet with protocol EthernetProtocolAll, then the + // previous for loop will have handled it. + if protocol != header.EthernetProtocolAll { + packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) + } + n.mu.RUnlock() + for _, ep := range packetEPs { + ep.HandlePacket(n.id, local, protocol, vv, linkHeader) + } + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } @@ -700,7 +734,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr return } - n.stack.stats.IP.InvalidAddressesReceived.Increment() + // If a packet socket handled the packet, don't treat it as invalid. + if len(packetEPs) == 0 { + n.stack.stats.IP.InvalidAddressesReceived.Increment() + } } // DeliverTransportPacket delivers the packets to the appropriate transport @@ -856,6 +893,36 @@ const ( temporary ) +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return tcpip.ErrNotSupported + } + n.packetEPs[netProto] = append(eps, ep) + + return nil +} + +func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return + } + + for i, epOther := range eps { + if epOther == ep { + n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) + return + } + } +} + type referencedNetworkEndpoint struct { ep NetworkEndpoint nic *NIC |