diff options
-rwxr-xr-x | pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go | 2 | ||||
-rwxr-xr-x | pkg/sentry/platform/ring0/defs_impl.go | 2 | ||||
-rwxr-xr-x | pkg/sentry/time/seqatomic_parameters_unsafe.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 123 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 5 |
5 files changed, 60 insertions, 74 deletions
diff --git a/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go index 25ad17a4e..c284a1b11 100755 --- a/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go +++ b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go @@ -1,11 +1,11 @@ package kernel import ( - "fmt" "reflect" "strings" "unsafe" + "fmt" "gvisor.dev/gvisor/third_party/gvsync" ) diff --git a/pkg/sentry/platform/ring0/defs_impl.go b/pkg/sentry/platform/ring0/defs_impl.go index a30a9dd4a..acae012dc 100755 --- a/pkg/sentry/platform/ring0/defs_impl.go +++ b/pkg/sentry/platform/ring0/defs_impl.go @@ -1,12 +1,12 @@ package ring0 import ( + "fmt" "gvisor.dev/gvisor/pkg/cpuid" "io" "reflect" "syscall" - "fmt" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" ) diff --git a/pkg/sentry/time/seqatomic_parameters_unsafe.go b/pkg/sentry/time/seqatomic_parameters_unsafe.go index 89792c56d..1ec221edd 100755 --- a/pkg/sentry/time/seqatomic_parameters_unsafe.go +++ b/pkg/sentry/time/seqatomic_parameters_unsafe.go @@ -1,11 +1,11 @@ package time import ( - "fmt" "reflect" "strings" "unsafe" + "fmt" "gvisor.dev/gvisor/third_party/gvsync" ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dc28dc970..4ef85bdfb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -186,41 +186,73 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) +} + +// getRefEpOrCreateTemp returns the referenced network endpoint for the given +// protocol and address. If none exists a temporary one may be created if +// requested. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - ref := n.endpoints[id] - if ref != nil && !ref.tryIncRef() { - ref = nil + + if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + + // The address was not found, create a temporary one if requested by the + // caller or if the address is found in the NIC's subnets. + createTempEP := allowTemp + if !createTempEP { + for _, sn := range n.subnets { + if sn.Contains(address) { + createTempEP = true + break + } + } } - spoofing := n.spoofing + n.mu.RUnlock() - if ref != nil || !spoofing { - return ref + if !createTempEP { + return nil } // Try again with the lock in exclusive mode. If we still can't get the // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - ref = n.endpoints[id] - if ref == nil || !ref.tryIncRef() { - if netProto, ok := n.stack.networkProtocols[protocol]; ok { - ref, _ = n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, true) - if ref != nil { - ref.holdsInsertRef = false - } - } + if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { + n.mu.Unlock() + return ref + } + + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + n.mu.Unlock() + return nil } + + ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, peb, true) + + if ref != nil { + ref.holdsInsertRef = false + } + n.mu.Unlock() return ref } @@ -553,57 +585,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.stats.IP.InvalidAddressesReceived.Increment() } -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - id := NetworkEndpointID{dst} - - n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - - promiscuous := n.promiscuous - // Check if the packet is for a subnet this NIC cares about. - if !promiscuous { - for _, sn := range n.subnets { - if sn.Contains(dst) { - promiscuous = true - break - } - } - } - n.mu.RUnlock() - if promiscuous { - // Try again with the lock in exclusive mode. If we still can't - // get the endpoint, create a new "temporary" one. It will only - // exist while there's a route through it. - n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref - } - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - n.mu.Unlock() - return nil - } - ref, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: dst, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, CanBePrimaryEndpoint, true) - n.mu.Unlock() - if err == nil { - ref.holdsInsertRef = false - return ref - } - } - - return nil -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f664673b3..ba2dd85b8 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -168,6 +168,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) { return Subnet{a, m}, nil } +// String implements Stringer. +func (s Subnet) String() string { + return fmt.Sprintf("%s/%d", s.ID(), s.Prefix()) +} + // Contains returns true iff the address is of the same length and matches the // subnet address and mask. func (s *Subnet) Contains(a Address) bool { |