summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/linux_abi_autogen_unsafe.go22
-rw-r--r--pkg/sentry/socket/netstack/netstack.go15
-rw-r--r--pkg/tcpip/ports/ports.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go60
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
5 files changed, 72 insertions, 46 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go
index 840b562db..a53fc398d 100644
--- a/pkg/abi/linux/linux_abi_autogen_unsafe.go
+++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go
@@ -152,7 +152,7 @@ func (s *Statx) UnmarshalBytes(src []byte) {
// Packed implements marshal.Marshallable.Packed.
//go:nosplit
func (s *Statx) Packed() bool {
- return s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed()
+ return s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed()
}
// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
@@ -167,7 +167,7 @@ func (s *Statx) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (s *Statx) UnmarshalUnsafe(src []byte) {
- if s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
+ if s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() {
safecopy.CopyOut(unsafe.Pointer(s), src)
} else {
// Type Statx doesn't have a packed layout in memory, fallback to UnmarshalBytes.
@@ -178,7 +178,7 @@ func (s *Statx) UnmarshalUnsafe(src []byte) {
// CopyOutN implements marshal.Marshallable.CopyOutN.
//go:nosplit
func (s *Statx) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
- if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
+ if !s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
s.MarshalBytes(buf) // escapes: fallback.
@@ -208,7 +208,7 @@ func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() {
+ if !s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
@@ -642,7 +642,7 @@ func (f *FUSEHeaderIn) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (f *FUSEHeaderIn) UnmarshalUnsafe(src []byte) {
- if f.Unique.Packed() && f.Opcode.Packed() {
+ if f.Opcode.Packed() && f.Unique.Packed() {
safecopy.CopyOut(unsafe.Pointer(f), src)
} else {
// Type FUSEHeaderIn doesn't have a packed layout in memory, fallback to UnmarshalBytes.
@@ -2035,7 +2035,7 @@ func (i *IPTEntry) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (i *IPTEntry) UnmarshalUnsafe(src []byte) {
- if i.Counters.Packed() && i.IP.Packed() {
+ if i.IP.Packed() && i.Counters.Packed() {
safecopy.CopyOut(unsafe.Pointer(i), src)
} else {
// Type IPTEntry doesn't have a packed layout in memory, fallback to UnmarshalBytes.
@@ -2208,12 +2208,12 @@ func (i *IPTIP) UnmarshalBytes(src []byte) {
// Packed implements marshal.Marshallable.Packed.
//go:nosplit
func (i *IPTIP) Packed() bool {
- return i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed()
+ return i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed()
}
// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
func (i *IPTIP) MarshalUnsafe(dst []byte) {
- if i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed() {
+ if i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
safecopy.CopyIn(dst, unsafe.Pointer(i))
} else {
// Type IPTIP doesn't have a packed layout in memory, fallback to MarshalBytes.
@@ -2234,7 +2234,7 @@ func (i *IPTIP) UnmarshalUnsafe(src []byte) {
// CopyOutN implements marshal.Marshallable.CopyOutN.
//go:nosplit
func (i *IPTIP) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
- if !i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
+ if !i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() {
// Type IPTIP doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := task.CopyScratchBuffer(i.SizeBytes()) // escapes: okay.
i.MarshalBytes(buf) // escapes: fallback.
@@ -2290,7 +2290,7 @@ func (i *IPTIP) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
// WriteTo implements io.WriterTo.WriteTo.
func (i *IPTIP) WriteTo(w io.Writer) (int64, error) {
- if !i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() {
+ if !i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
// Type IPTIP doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := make([]byte, i.SizeBytes())
i.MarshalBytes(buf)
@@ -3201,7 +3201,7 @@ func (i *IP6TIP) Packed() bool {
// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
func (i *IP6TIP) MarshalUnsafe(dst []byte) {
- if i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() {
+ if i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
safecopy.CopyIn(dst, unsafe.Pointer(i))
} else {
// Type IP6TIP doesn't have a packed layout in memory, fallback to MarshalBytes.
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 0e5913b60..4d0e33696 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -803,7 +803,20 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Issue the bind request to the endpoint.
- return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
+ err := s.Endpoint.Bind(addr)
+ if err == tcpip.ErrNoPortAvailable {
+ // Bind always returns EADDRINUSE irrespective of if the specified port was
+ // already bound or if an ephemeral port was requested but none were
+ // available.
+ //
+ // tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because
+ // UDP connect returns EAGAIN on ephemeral port exhaustion.
+ //
+ // TCP connect returns EADDRNOTAVAIL on ephemeral port exhaustion.
+ err = tcpip.ErrPortInUse
+ }
+
+ return syserr.TranslateNetstackError(err)
}
// Listen implements the linux syscall listen(2) for sockets backed by
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index f6d592eb5..d87193650 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -400,7 +400,11 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) {
+//
+// An optional testPort closure can be passed in which if provided will be used
+// to test if the picked port can be used. The function should return true if
+// the port is safe to use, false otherwise.
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -412,12 +416,23 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
return 0, tcpip.ErrPortInUse
}
+ if testPort != nil && !testPort(port) {
+ s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst)
+ return 0, tcpip.ErrPortInUse
+ }
return port, nil
}
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil
+ if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) {
+ return false, nil
+ }
+ if testPort != nil && !testPort(p) {
+ s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst)
+ return false, nil
+ }
+ return true, nil
})
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 21a4b6e2f..9df22ac84 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2169,7 +2169,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
if err != tcpip.ErrPortInUse || !reuse {
return false, nil
}
@@ -2207,7 +2207,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
return false, nil
}
}
@@ -2505,47 +2505,45 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
- if err != nil {
- return err
- }
-
- e.boundBindToDevice = e.bindToDevice
- e.boundPortFlags = e.portFlags
- e.isPortReserved = true
- e.effectiveNetProtos = netProtos
- e.ID.LocalPort = port
-
- // Any failures beyond this point must remove the port registration.
- defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
- if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{})
- e.isPortReserved = false
- e.effectiveNetProtos = nil
- e.ID.LocalPort = 0
- e.ID.LocalAddress = ""
- e.boundNICID = 0
- e.boundBindToDevice = 0
- e.boundPortFlags = ports.Flags{}
- }
- }(e.boundPortFlags, e.boundBindToDevice)
-
+ var nic tcpip.NICID
// If an address is specified, we must ensure that it's one of our
// local addresses.
if len(addr.Addr) != 0 {
- nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nic == 0 {
return tcpip.ErrBadLocalAddress
}
-
- e.boundNICID = nic
e.ID.LocalAddress = addr.Addr
}
- if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ id := e.ID
+ id.LocalPort = p
+ // CheckRegisterTransportEndpoint should only return an error if there is a
+ // listening endpoint bound with the same id and portFlags and bindToDevice
+ // options.
+ //
+ // NOTE: Only listening and connected endpoint register with
+ // demuxer. Further connected endpoints always have a remote
+ // address/port. Hence this will only return an error if there is a matching
+ // listening endpoint.
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ return false
+ }
+ return true
+ })
+ if err != nil {
return err
}
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
+ e.boundNICID = nic
+ e.isPortReserved = true
+ e.effectiveNetProtos = netProtos
+ e.ID.LocalPort = port
+
// Mark endpoint as bound.
e.setEndpointState(StateBound)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 73608783c..c33434b75 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1226,7 +1226,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
return id, e.bindToDevice, err
}