summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/nic.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r--pkg/tcpip/stack/nic.go453
1 files changed, 453 insertions, 0 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
new file mode 100644
index 000000000..8ff4310d5
--- /dev/null
+++ b/pkg/tcpip/stack/nic.go
@@ -0,0 +1,453 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/ilist"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// NIC represents a "network interface card" to which the networking stack is
+// attached.
+type NIC struct {
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+
+ demux *transportDemuxer
+
+ mu sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ subnets []tcpip.Subnet
+}
+
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
+ return &NIC{
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ demux: newTransportDemuxer(stack),
+ primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ }
+}
+
+// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
+// to start delivering packets.
+func (n *NIC) attachLinkEndpoint() {
+ n.linkEP.Attach(n)
+}
+
+// setPromiscuousMode enables or disables promiscuous mode.
+func (n *NIC) setPromiscuousMode(enable bool) {
+ n.mu.Lock()
+ n.promiscuous = enable
+ n.mu.Unlock()
+}
+
+// setSpoofing enables or disables address spoofing.
+func (n *NIC) setSpoofing(enable bool) {
+ n.mu.Lock()
+ n.spoofing = enable
+ n.mu.Unlock()
+}
+
+// primaryEndpoint returns the primary endpoint of n for the given network
+// protocol.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ list := n.primary[protocol]
+ if list == nil {
+ return nil
+ }
+
+ for e := list.Front(); e != nil; e = e.Next() {
+ r := e.(*referencedNetworkEndpoint)
+ if r.tryIncRef() {
+ return r
+ }
+ }
+
+ return nil
+}
+
+// findEndpoint finds the endpoint, if any, with the given address.
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address) *referencedNetworkEndpoint {
+ id := NetworkEndpointID{address}
+
+ n.mu.RLock()
+ ref := n.endpoints[id]
+ if ref != nil && !ref.tryIncRef() {
+ ref = nil
+ }
+ spoofing := n.spoofing
+ n.mu.RUnlock()
+
+ if ref != nil || !spoofing {
+ return ref
+ }
+
+ // Try again with the lock in exclusive mode. If we still can't get the
+ // endpoint, create a new "temporary" endpoint. It will only exist while
+ // there's a route through it.
+ n.mu.Lock()
+ ref = n.endpoints[id]
+ if ref == nil || !ref.tryIncRef() {
+ ref, _ = n.addAddressLocked(protocol, address, true)
+ if ref != nil {
+ ref.holdsInsertRef = false
+ }
+ }
+ n.mu.Unlock()
+ return ref
+}
+
+func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ // Create the new network endpoint.
+ ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
+ if err != nil {
+ return nil, err
+ }
+
+ id := *ep.ID()
+ if ref, ok := n.endpoints[id]; ok {
+ if !replace {
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ n.removeEndpointLocked(ref)
+ }
+
+ ref := &referencedNetworkEndpoint{
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocol,
+ holdsInsertRef: true,
+ }
+
+ // Set up cache if link address resolution exists for this protocol.
+ if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes := n.stack.linkAddrResolvers[protocol]; linkRes != nil {
+ ref.linkCache = n.stack
+ }
+ }
+
+ n.endpoints[id] = ref
+
+ l, ok := n.primary[protocol]
+ if !ok {
+ l = &ilist.List{}
+ n.primary[protocol] = l
+ }
+
+ l.PushBack(ref)
+
+ return ref, nil
+}
+
+// AddAddress adds a new address to n, so that it starts accepting packets
+// targeted at the given address (and network protocol).
+func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ // Add the endpoint.
+ n.mu.Lock()
+ _, err := n.addAddressLocked(protocol, addr, false)
+ n.mu.Unlock()
+
+ return err
+}
+
+// Addresses returns the addresses associated with this NIC.
+func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
+ for nid, ep := range n.endpoints {
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: ep.protocol,
+ Address: nid.LocalAddress,
+ })
+ }
+ return addrs
+}
+
+// AddSubnet adds a new subnet to n, so that it starts accepting packets
+// targeted at the given address and network protocol.
+func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+ n.mu.Lock()
+ n.subnets = append(n.subnets, subnet)
+ n.mu.Unlock()
+}
+
+// Subnets returns the Subnets associated with this NIC.
+func (n *NIC) Subnets() []tcpip.Subnet {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
+ for nid := range n.endpoints {
+ sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
+ if err != nil {
+ // This should never happen as the mask has been carefully crafted to
+ // match the address.
+ panic("Invalid endpoint subnet: " + err.Error())
+ }
+ sns = append(sns, sn)
+ }
+ return append(sns, n.subnets...)
+}
+
+func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
+ id := *r.ep.ID()
+
+ // Nothing to do if the reference has already been replaced with a
+ // different one.
+ if n.endpoints[id] != r {
+ return
+ }
+
+ if r.holdsInsertRef {
+ panic("Reference count dropped to zero before being removed")
+ }
+
+ delete(n.endpoints, id)
+ n.primary[r.protocol].Remove(r)
+ r.ep.Close()
+}
+
+func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
+ n.mu.Lock()
+ n.removeEndpointLocked(r)
+ n.mu.Unlock()
+}
+
+// RemoveAddress removes an address from n.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ r := n.endpoints[NetworkEndpointID{addr}]
+ if r == nil || !r.holdsInsertRef {
+ n.mu.Unlock()
+ return tcpip.ErrBadLocalAddress
+ }
+
+ r.holdsInsertRef = false
+ n.mu.Unlock()
+
+ r.decRef()
+
+ return nil
+}
+
+// DeliverNetworkPacket finds the appropriate network protocol endpoint and
+// hands the packet over for further processing. This function is called when
+// the NIC receives a packet from the physical interface.
+// Note that the ownership of the slice backing vv is retained by the caller.
+// This rule applies only to the slice itself, not to the items of the slice;
+// the ownership of the items is not retained by the caller.
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1)
+ return
+ }
+
+ if len(vv.First()) < netProto.MinimumPacketSize() {
+ atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ return
+ }
+
+ src, dst := netProto.ParseAddresses(vv.First())
+ id := NetworkEndpointID{dst}
+
+ n.mu.RLock()
+ ref := n.endpoints[id]
+ if ref != nil && !ref.tryIncRef() {
+ ref = nil
+ }
+ promiscuous := n.promiscuous
+ subnets := n.subnets
+ n.mu.RUnlock()
+
+ if ref == nil {
+ // Check if the packet is for a subnet this NIC cares about.
+ if !promiscuous {
+ for _, sn := range subnets {
+ if sn.Contains(dst) {
+ promiscuous = true
+ break
+ }
+ }
+ }
+ if promiscuous {
+ // Try again with the lock in exclusive mode. If we still can't
+ // get the endpoint, create a new "temporary" one. It will only
+ // exist while there's a route through it.
+ n.mu.Lock()
+ ref = n.endpoints[id]
+ if ref == nil || !ref.tryIncRef() {
+ ref, _ = n.addAddressLocked(protocol, dst, true)
+ if ref != nil {
+ ref.holdsInsertRef = false
+ }
+ }
+ n.mu.Unlock()
+ }
+ }
+
+ if ref == nil {
+ atomic.AddUint64(&n.stack.stats.UnknownNetworkEndpointRcvdPackets, 1)
+ return
+ }
+
+ r := makeRoute(protocol, dst, src, ref)
+ r.LocalLinkAddress = linkEP.LinkAddress()
+ r.RemoteLinkAddress = remoteLinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+}
+
+// DeliverTransportPacket delivers the packets to the appropriate transport
+// protocol endpoint.
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) {
+ state, ok := n.stack.transportProtocols[protocol]
+ if !ok {
+ atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1)
+ return
+ }
+
+ transProto := state.proto
+ if len(vv.First()) < transProto.MinimumPacketSize() {
+ atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ if err != nil {
+ atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ return
+ }
+
+ id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
+ if n.demux.deliverPacket(r, protocol, vv, id) {
+ return
+ }
+ if n.stack.demux.deliverPacket(r, protocol, vv, id) {
+ return
+ }
+
+ // Try to deliver to per-stack default handler.
+ if state.defaultHandler != nil {
+ if state.defaultHandler(r, id, vv) {
+ return
+ }
+ }
+
+ // We could not find an appropriate destination for this packet, so
+ // deliver it to the global handler.
+ if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
+ atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ }
+}
+
+// DeliverTransportControlPacket delivers control packets to the appropriate
+// transport protocol endpoint.
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView) {
+ state, ok := n.stack.transportProtocols[trans]
+ if !ok {
+ return
+ }
+
+ transProto := state.proto
+
+ // ICMPv4 only guarantees that 8 bytes of the transport protocol will
+ // be present in the payload. We know that the ports are within the
+ // first 8 bytes for all known transport protocols.
+ if len(vv.First()) < 8 {
+ return
+ }
+
+ srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ if err != nil {
+ return
+ }
+
+ id := TransportEndpointID{srcPort, local, dstPort, remote}
+ if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ return
+ }
+ if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ return
+ }
+}
+
+// ID returns the identifier of n.
+func (n *NIC) ID() tcpip.NICID {
+ return n.id
+}
+
+type referencedNetworkEndpoint struct {
+ ilist.Entry
+ refs int32
+ ep NetworkEndpoint
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkCache is set if link address resolution is enabled for this
+ // protocol. Set to nil otherwise.
+ linkCache LinkAddressCache
+
+ // holdsInsertRef is protected by the NIC's mutex. It indicates whether
+ // the reference count is biased by 1 due to the insertion of the
+ // endpoint. It is reset to false when RemoveAddress is called on the
+ // NIC.
+ holdsInsertRef bool
+}
+
+// decRef decrements the ref count and cleans up the endpoint once it reaches
+// zero.
+func (r *referencedNetworkEndpoint) decRef() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpoint(r)
+ }
+}
+
+// incRef increments the ref count. It must only be called when the caller is
+// known to be holding a reference to the endpoint, otherwise tryIncRef should
+// be used.
+func (r *referencedNetworkEndpoint) incRef() {
+ atomic.AddInt32(&r.refs, 1)
+}
+
+// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
+// not zero. That is, it will increment the count if the endpoint is still
+// alive, and do nothing if it has already been clean up.
+func (r *referencedNetworkEndpoint) tryIncRef() bool {
+ for {
+ v := atomic.LoadInt32(&r.refs)
+ if v == 0 {
+ return false
+ }
+
+ if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
+ return true
+ }
+ }
+}