summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rwxr-xr-xpkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go2
-rwxr-xr-xpkg/sentry/platform/ring0/defs_impl.go2
-rwxr-xr-xpkg/sentry/time/seqatomic_parameters_unsafe.go2
-rw-r--r--pkg/tcpip/stack/nic.go123
-rw-r--r--pkg/tcpip/tcpip.go5
5 files changed, 60 insertions, 74 deletions
diff --git a/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go
index 25ad17a4e..c284a1b11 100755
--- a/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go
+++ b/pkg/sentry/kernel/seqatomic_taskgoroutineschedinfo_unsafe.go
@@ -1,11 +1,11 @@
package kernel
import (
- "fmt"
"reflect"
"strings"
"unsafe"
+ "fmt"
"gvisor.dev/gvisor/third_party/gvsync"
)
diff --git a/pkg/sentry/platform/ring0/defs_impl.go b/pkg/sentry/platform/ring0/defs_impl.go
index a30a9dd4a..acae012dc 100755
--- a/pkg/sentry/platform/ring0/defs_impl.go
+++ b/pkg/sentry/platform/ring0/defs_impl.go
@@ -1,12 +1,12 @@
package ring0
import (
+ "fmt"
"gvisor.dev/gvisor/pkg/cpuid"
"io"
"reflect"
"syscall"
- "fmt"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
diff --git a/pkg/sentry/time/seqatomic_parameters_unsafe.go b/pkg/sentry/time/seqatomic_parameters_unsafe.go
index 89792c56d..1ec221edd 100755
--- a/pkg/sentry/time/seqatomic_parameters_unsafe.go
+++ b/pkg/sentry/time/seqatomic_parameters_unsafe.go
@@ -1,11 +1,11 @@
package time
import (
- "fmt"
"reflect"
"strings"
"unsafe"
+ "fmt"
"gvisor.dev/gvisor/third_party/gvsync"
)
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index dc28dc970..4ef85bdfb 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -186,41 +186,73 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return nil
}
+func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing)
+}
+
+// getRefEpOrCreateTemp returns the referenced network endpoint for the given
+// protocol and address. If none exists a temporary one may be created if
+// requested.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- ref := n.endpoints[id]
- if ref != nil && !ref.tryIncRef() {
- ref = nil
+
+ if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+
+ // The address was not found, create a temporary one if requested by the
+ // caller or if the address is found in the NIC's subnets.
+ createTempEP := allowTemp
+ if !createTempEP {
+ for _, sn := range n.subnets {
+ if sn.Contains(address) {
+ createTempEP = true
+ break
+ }
+ }
}
- spoofing := n.spoofing
+
n.mu.RUnlock()
- if ref != nil || !spoofing {
- return ref
+ if !createTempEP {
+ return nil
}
// Try again with the lock in exclusive mode. If we still can't get the
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- ref = n.endpoints[id]
- if ref == nil || !ref.tryIncRef() {
- if netProto, ok := n.stack.networkProtocols[protocol]; ok {
- ref, _ = n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb, true)
- if ref != nil {
- ref.holdsInsertRef = false
- }
- }
+ if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
+ }
+
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ n.mu.Unlock()
+ return nil
}
+
+ ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, peb, true)
+
+ if ref != nil {
+ ref.holdsInsertRef = false
+ }
+
n.mu.Unlock()
return ref
}
@@ -553,57 +585,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
-func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- id := NetworkEndpointID{dst}
-
- n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
-
- promiscuous := n.promiscuous
- // Check if the packet is for a subnet this NIC cares about.
- if !promiscuous {
- for _, sn := range n.subnets {
- if sn.Contains(dst) {
- promiscuous = true
- break
- }
- }
- }
- n.mu.RUnlock()
- if promiscuous {
- // Try again with the lock in exclusive mode. If we still can't
- // get the endpoint, create a new "temporary" one. It will only
- // exist while there's a route through it.
- n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.Unlock()
- return ref
- }
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- n.mu.Unlock()
- return nil
- }
- ref, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: dst,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, CanBePrimaryEndpoint, true)
- n.mu.Unlock()
- if err == nil {
- ref.holdsInsertRef = false
- return ref
- }
- }
-
- return nil
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f664673b3..ba2dd85b8 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -168,6 +168,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) {
return Subnet{a, m}, nil
}
+// String implements Stringer.
+func (s Subnet) String() string {
+ return fmt.Sprintf("%s/%d", s.ID(), s.Prefix())
+}
+
// Contains returns true iff the address is of the same length and matches the
// subnet address and mask.
func (s *Subnet) Contains(a Address) bool {