summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go45
-rw-r--r--pkg/sentry/socket/unix/unix.go2
-rw-r--r--pkg/sentry/strace/socket.go2
-rw-r--r--pkg/tcpip/stack/transport_test.go5
-rw-r--r--pkg/tcpip/tcpip.go3
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go10
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go10
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go10
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go6
9 files changed, 55 insertions, 38 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 4d8a5ac22..635042263 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -291,18 +291,22 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// GetAddress reads an sockaddr struct from the given address and converts it
-// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
-// addresses.
-func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
+// AF_INET6 addresses.
+//
+// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
+//
+// AddressAndFamily returns an address, its family.
+func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
}
// Get the rest of the fields based on the address family.
@@ -310,7 +314,7 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
// Drop the terminating NUL (if one exists) and everything after
// it for filesystem (non-abstract) addresses.
@@ -321,12 +325,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
}
return tcpip.FullAddress{
Addr: tcpip.Address(path),
- }, nil
+ }, family, nil
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@@ -334,12 +338,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
Addr: bytesToIPAddress(a.Addr[:]),
Port: ntohs(a.Port),
}
- return out, nil
+ return out, family, nil
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@@ -350,13 +354,13 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
if isLinkLocal(out.Addr) {
out.NIC = tcpip.NICID(a.Scope_id)
}
- return out, nil
+ return out, family, nil
case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, nil
+ return tcpip.FullAddress{}, family, nil
default:
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
}
}
@@ -482,11 +486,18 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, false /* strict */)
+ addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
+ if family == linux.AF_UNSPEC {
+ err := s.Endpoint.Disconnect()
+ if err == tcpip.ErrNotSupported {
+ return syserr.ErrAddressFamilyNotSupported
+ }
+ return syserr.TranslateNetstackError(err)
+ }
// Always return right away in the non-blocking case.
if !blocking {
return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
@@ -515,7 +526,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, true /* strict */)
+ addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@@ -2023,7 +2034,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, err := GetAddress(s.family, to, true /* strict */)
+ addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */)
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 8a3f65236..0d0cb68df 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 386b40af7..f779186ad 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, err := epsocket.GetAddress(int(family), b, true /* strict */)
+ fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 8146e8444..1c811ab68 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -105,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index ba2dd85b8..29a6025d9 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -358,6 +358,9 @@ type Endpoint interface {
// ErrAddressFamilyNotSupported must be returned.
Connect(address FullAddress) *Error
+ // Disconnect disconnects the endpoint from its peer.
+ Disconnect() *Error
+
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
Shutdown(flags ShutdownFlags) *Error
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 2e8d5d4bf..451d3880e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -428,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
nicid := addr.NIC
localPort := uint16(0)
switch e.state {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 53c9515a4..13e17e2a6 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -349,16 +349,16 @@ func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error)
return 0, tcpip.ControlMessages{}, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect implements tcpip.Endpoint.Connect.
func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
if ep.closed {
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e5f835c20..24b32e4af 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1362,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- if addr.Addr == "" && addr.Port == 0 {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
return e.connect(addr, true, true)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 640bb8667..8b3356406 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -711,7 +711,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
-func (e *endpoint) disconnect() *tcpip.Error {
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -750,9 +751,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- if addr.Addr == "" {
- return e.disconnect()
- }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState