summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go78
-rw-r--r--pkg/sentry/socket/control/control_vfs2.go10
-rw-r--r--pkg/sentry/socket/hostinet/socket.go50
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go3
-rw-r--r--pkg/sentry/socket/hostinet/stack.go21
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go63
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go4
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go4
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go52
-rw-r--r--pkg/sentry/socket/netfilter/targets.go217
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go2
-rw-r--r--pkg/sentry/socket/netlink/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go62
-rw-r--r--pkg/sentry/socket/netlink/socket.go12
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go716
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go10
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/socket/netstack/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/stack.go115
-rw-r--r--pkg/sentry/socket/socket.go249
-rw-r--r--pkg/sentry/socket/unix/BUILD5
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD3
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go35
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go99
-rw-r--r--pkg/sentry/socket/unix/unix.go15
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go5
31 files changed, 1142 insertions, 703 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index a3f775d15..cc1f6bfcc 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index ca16d0381..fb7c5dc61 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -23,7 +23,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//pkg/tcpip",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 70ccf77a7..b88cdca48 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -344,21 +343,34 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
}
// PackIPPacketInfo packs an IP_PKTINFO socket control message.
-func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
- var p linux.ControlMessageIPPacketInfo
- p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
-
+func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
linux.IP_PKTINFO,
t.Arch().Width(),
- p,
+ packetInfo,
)
}
+// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
+func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
+ var level uint32
+ var optType uint32
+ switch originalDstAddress.(type) {
+ case *linux.SockAddrInet:
+ level = linux.SOL_IP
+ optType = linux.IP_RECVORIGDSTADDR
+ case *linux.SockAddrInet6:
+ level = linux.SOL_IPV6
+ optType = linux.IPV6_RECVORIGDSTADDR
+ default:
+ panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
+ }
+ return putCmsgStruct(
+ buf, level, optType, t.Arch().Width(), originalDstAddress)
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -384,7 +396,11 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
}
if cmsgs.IP.HasIPPacketInfo {
- buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
+ }
+
+ if cmsgs.IP.OriginalDstAddress != nil {
+ buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
return buf
@@ -416,17 +432,15 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
}
- return space
-}
+ if cmsgs.IP.HasIPPacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
+ }
-// NewIPPacketInfo returns the IPPacketInfo struct.
-func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
- var p tcpip.IPPacketInfo
- p.NIC = tcpip.NICID(packetInfo.NIC)
- copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
- copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+ if cmsgs.IP.OriginalDstAddress != nil {
+ space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
+ }
- return p
+ return space
}
// Parse parses a raw socket control message into portable objects.
@@ -489,6 +503,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
cmsgs.Unix.Credentials = scmCreds
i += binary.AlignUp(length, width)
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTimestamp = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp)
+ i += binary.AlignUp(length, width)
+
default:
// Unknown message type.
return socket.ControlMessages{}, syserror.EINVAL
@@ -512,7 +534,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ cmsgs.IP.PacketInfo = packetInfo
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
i += binary.AlignUp(length, width)
default:
@@ -528,6 +559,15 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += binary.AlignUp(length, width)
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
index d9621968c..37d02948f 100644
--- a/pkg/sentry/socket/control/control_vfs2.go
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -24,6 +24,8 @@ import (
)
// SCMRightsVFS2 represents a SCM_RIGHTS socket control message.
+//
+// +stateify savable
type SCMRightsVFS2 interface {
transport.RightsControlMessage
@@ -34,9 +36,11 @@ type SCMRightsVFS2 interface {
Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool)
}
-// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
-// maintained for each vfs.FileDescription and is release either when an FD is created or
-// when the Release method is called.
+// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference
+// is maintained for each vfs.FileDescription and is release either when an FD
+// is created or when the Release method is called.
+//
+// +stateify savable
type RightsFilesVFS2 []*vfs.FileDescription
// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7d3c4a01c..be418df2e 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
case linux.SO_LINGER:
optlen = syscall.SizeofLinger
@@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR:
optlen = sizeofInt32
case linux.IP_PKTINFO:
optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
}
case linux.SOL_TCP:
switch name {
- case linux.TCP_NODELAY:
+ case linux.TCP_NODELAY, linux.TCP_INQ:
optlen = sizeofInt32
}
}
@@ -513,24 +513,48 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
controlMessages := socket.ControlMessages{}
for _, unixCmsg := range unixControlMessages {
switch unixCmsg.Header.Level {
- case syscall.SOL_IP:
+ case linux.SOL_SOCKET:
switch unixCmsg.Header.Type {
- case syscall.IP_TOS:
+ case linux.SO_TIMESTAMP:
+ controlMessages.IP.HasTimestamp = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp)
+ }
+
+ case linux.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case linux.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
- case syscall.IP_PKTINFO:
+ case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ controlMessages.IP.PacketInfo = packetInfo
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
}
- case syscall.SOL_IPV6:
+ case linux.SOL_IPV6:
switch unixCmsg.Header.Type {
- case syscall.IPV6_TCLASS:
+ case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+ }
+
+ case linux.SOL_TCP:
+ switch unixCmsg.Header.Type {
+ case linux.TCP_INQ:
+ controlMessages.IP.HasInq = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq)
}
}
}
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 163af329b..9a2cac40b 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// +stateify savable
type socketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{})
func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
s := &socketVFS2{
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index faa61160e..7e7857ac3 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
}
// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error {
+ return syserror.EACCES
+}
+
+// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
+func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error {
return syserror.EACCES
}
@@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) {
}
// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+func (s *Stack) SetTCPSACKEnabled(bool) error {
return syserror.EACCES
}
@@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
}
// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
-func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
+func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error {
return syserror.EACCES
}
@@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
if rawLine == "" {
- return fmt.Errorf("Failed to get raw line")
+ return fmt.Errorf("failed to get raw line")
}
parts := strings.SplitN(rawLine, ":", 2)
if len(parts) != 2 {
- return fmt.Errorf("Failed to get prefix from: %q", rawLine)
+ return fmt.Errorf("failed to get prefix from: %q", rawLine)
}
sliceStat = toSlice(stat)
fields := strings.Fields(strings.TrimSpace(parts[1]))
if len(fields) != len(sliceStat) {
- return fmt.Errorf("Failed to parse fields: %q", rawLine)
+ return fmt.Errorf("failed to parse fields: %q", rawLine)
}
if _, ok := stat.(*inet.StatSNMPTCP); ok {
snmpTCP = true
@@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64)
}
if err != nil {
- return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err)
+ return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err)
}
}
@@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
}
// SetForwarding implements inet.Stack.SetForwarding.
-func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
+func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return syserror.EACCES
}
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 549787955..e0976fed0 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -100,24 +100,43 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf
// marshalTarget and unmarshalTarget can be used.
type targetMaker interface {
// id uniquely identifies the target.
- id() stack.TargetID
+ id() targetID
- // marshal converts from a stack.Target to an ABI struct.
- marshal(target stack.Target) []byte
+ // marshal converts from a target to an ABI struct.
+ marshal(target target) []byte
- // unmarshal converts from the ABI matcher struct to a stack.Target.
- unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error)
+ // unmarshal converts from the ABI matcher struct to a target.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error)
}
-// targetMakers maps the TargetID of supported targets to the targetMaker that
+// A targetID uniquely identifies a target.
+type targetID struct {
+ // name is the target name as stored in the xt_entry_target struct.
+ name string
+
+ // networkProtocol is the protocol to which the target applies.
+ networkProtocol tcpip.NetworkProtocolNumber
+
+ // revision is the version of the target.
+ revision uint8
+}
+
+// target extends a stack.Target, allowing it to be used with the extension
+// system. The sentry only uses targets, never stack.Targets directly.
+type target interface {
+ stack.Target
+ id() targetID
+}
+
+// targetMakers maps the targetID of supported targets to the targetMaker that
// marshals and unmarshals it. It is immutable after package initialization.
-var targetMakers = map[stack.TargetID]targetMaker{}
+var targetMakers = map[targetID]targetMaker{}
func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) {
- tid := stack.TargetID{
- Name: name,
- NetworkProtocol: netProto,
- Revision: rev,
+ tid := targetID{
+ name: name,
+ networkProtocol: netProto,
+ revision: rev,
}
if _, ok := targetMakers[tid]; !ok {
return 0, false
@@ -126,8 +145,8 @@ func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8
// Return the highest supported revision unless rev is higher.
for _, other := range targetMakers {
otherID := other.id()
- if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev {
- rev = uint8(otherID.Revision)
+ if name == otherID.name && netProto == otherID.networkProtocol && otherID.revision > rev {
+ rev = uint8(otherID.revision)
}
}
return rev, true
@@ -142,19 +161,21 @@ func registerTargetMaker(tm targetMaker) {
targetMakers[tm.id()] = tm
}
-func marshalTarget(target stack.Target) []byte {
- targetMaker, ok := targetMakers[target.ID()]
+func marshalTarget(tgt stack.Target) []byte {
+ // The sentry only uses targets, never stack.Targets directly.
+ target := tgt.(target)
+ targetMaker, ok := targetMakers[target.id()]
if !ok {
- panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID()))
+ panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.id()))
}
return targetMaker.marshal(target)
}
-func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) {
- tid := stack.TargetID{
- Name: target.Name.String(),
- NetworkProtocol: filter.NetworkProtocol(),
- Revision: target.Revision,
+func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (target, *syserr.Error) {
+ tid := targetID{
+ name: target.Name.String(),
+ networkProtocol: filter.NetworkProtocol(),
+ revision: target.Revision,
}
targetMaker, ok := targetMakers[tid]
if !ok {
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index b560fae0d..70c561cce 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -46,13 +46,13 @@ func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linu
return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- table, ok := stk.IPTables().GetTable(tablename.String(), false)
+ id, ok := nameToID[tablename.String()]
if !ok {
return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
}
// Setup the info struct.
- entries, info := getEntries4(table, tablename)
+ entries, info := getEntries4(stk.IPTables().GetTable(id, false), tablename)
return entries, info, nil
}
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 4253f7bf4..5dbb604f0 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -46,13 +46,13 @@ func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linu
return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- table, ok := stk.IPTables().GetTable(tablename.String(), true)
+ id, ok := nameToID[tablename.String()]
if !ok {
return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
}
// Setup the info struct, which is the same in IPv4 and IPv6.
- entries, info := getEntries6(table, tablename)
+ entries, info := getEntries6(stk.IPTables().GetTable(id, true), tablename)
return entries, info, nil
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 904a12e38..b283d7229 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -42,6 +42,45 @@ func nflog(format string, args ...interface{}) {
}
}
+// Table names.
+const (
+ natTable = "nat"
+ mangleTable = "mangle"
+ filterTable = "filter"
+)
+
+// nameToID is immutable.
+var nameToID = map[string]stack.TableID{
+ natTable: stack.NATID,
+ mangleTable: stack.MangleID,
+ filterTable: stack.FilterID,
+}
+
+// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for
+// compatibility with netfilter extensions.
+func DefaultLinuxTables() *stack.IPTables {
+ tables := stack.DefaultTables()
+ tables.VisitTargets(func(oldTarget stack.Target) stack.Target {
+ switch val := oldTarget.(type) {
+ case *stack.AcceptTarget:
+ return &acceptTarget{AcceptTarget: *val}
+ case *stack.DropTarget:
+ return &dropTarget{DropTarget: *val}
+ case *stack.ErrorTarget:
+ return &errorTarget{ErrorTarget: *val}
+ case *stack.UserChainTarget:
+ return &userChainTarget{UserChainTarget: *val}
+ case *stack.ReturnTarget:
+ return &returnTarget{ReturnTarget: *val}
+ case *stack.RedirectTarget:
+ return &redirectTarget{RedirectTarget: *val}
+ default:
+ panic(fmt.Sprintf("Unknown rule in default iptables of type %T", val))
+ }
+ })
+ return tables
+}
+
// GetInfo returns information about iptables.
func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
@@ -144,9 +183,9 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
- case stack.FilterTable:
+ case filterTable:
table = stack.EmptyFilterTable()
- case stack.NATTable:
+ case natTable:
table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
@@ -177,7 +216,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
if offset == replace.Underflow[hook] {
if !validUnderflow(table.Rules[ruleIdx], ipv6) {
- nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP: %+v", ruleIdx)
return syserr.ErrInvalidArgument
}
table.Underflows[hk] = ruleIdx
@@ -253,8 +292,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6))
-
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(nameToID[replace.Name.String()], table, ipv6))
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
@@ -308,7 +346,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool {
return false
}
switch rule.Target.(type) {
- case *stack.AcceptTarget, *stack.DropTarget:
+ case *acceptTarget, *dropTarget:
return true
default:
return false
@@ -319,7 +357,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool {
if !validUnderflow(rule, ipv6) {
return false
}
- _, ok := rule.Target.(*stack.AcceptTarget)
+ _, ok := rule.Target.(*acceptTarget)
return ok
}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 0e14447fe..f2653d523 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -26,6 +26,15 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// ErrorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const ErrorTargetName = "ERROR"
+
+// RedirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port and/or IP for packets.
+const RedirectTargetName = "REDIRECT"
+
func init() {
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
registerTargetMaker(&standardTargetMaker{
@@ -52,25 +61,96 @@ func init() {
})
}
+// The stack package provides some basic, useful targets for us. The following
+// types wrap them for compatibility with the extension system.
+
+type acceptTarget struct {
+ stack.AcceptTarget
+}
+
+func (at *acceptTarget) id() targetID {
+ return targetID{
+ networkProtocol: at.NetworkProtocol,
+ }
+}
+
+type dropTarget struct {
+ stack.DropTarget
+}
+
+func (dt *dropTarget) id() targetID {
+ return targetID{
+ networkProtocol: dt.NetworkProtocol,
+ }
+}
+
+type errorTarget struct {
+ stack.ErrorTarget
+}
+
+func (et *errorTarget) id() targetID {
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: et.NetworkProtocol,
+ }
+}
+
+type userChainTarget struct {
+ stack.UserChainTarget
+}
+
+func (uc *userChainTarget) id() targetID {
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: uc.NetworkProtocol,
+ }
+}
+
+type returnTarget struct {
+ stack.ReturnTarget
+}
+
+func (rt *returnTarget) id() targetID {
+ return targetID{
+ networkProtocol: rt.NetworkProtocol,
+ }
+}
+
+type redirectTarget struct {
+ stack.RedirectTarget
+
+ // addr must be (un)marshalled when reading and writing the target to
+ // userspace, but does not affect behavior.
+ addr tcpip.Address
+}
+
+func (rt *redirectTarget) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rt.NetworkProtocol,
+ }
+}
+
type standardTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (sm *standardTargetMaker) id() stack.TargetID {
+func (sm *standardTargetMaker) id() targetID {
// Standard targets have the empty string as a name and no revisions.
- return stack.TargetID{
- NetworkProtocol: sm.NetworkProtocol,
+ return targetID{
+ networkProtocol: sm.NetworkProtocol,
}
}
-func (*standardTargetMaker) marshal(target stack.Target) []byte {
+
+func (*standardTargetMaker) marshal(target target) []byte {
// Translate verdicts the same way as the iptables tool.
var verdict int32
switch tg := target.(type) {
- case *stack.AcceptTarget:
+ case *acceptTarget:
verdict = -linux.NF_ACCEPT - 1
- case *stack.DropTarget:
+ case *dropTarget:
verdict = -linux.NF_DROP - 1
- case *stack.ReturnTarget:
+ case *returnTarget:
verdict = linux.NF_RETURN
case *JumpTarget:
verdict = int32(tg.Offset)
@@ -90,7 +170,7 @@ func (*standardTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) != linux.SizeOfXTStandardTarget {
nflog("buf has wrong size for standard target %d", len(buf))
return nil, syserr.ErrInvalidArgument
@@ -114,20 +194,20 @@ type errorTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (em *errorTargetMaker) id() stack.TargetID {
+func (em *errorTargetMaker) id() targetID {
// Error targets have no revision.
- return stack.TargetID{
- Name: stack.ErrorTargetName,
- NetworkProtocol: em.NetworkProtocol,
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: em.NetworkProtocol,
}
}
-func (*errorTargetMaker) marshal(target stack.Target) []byte {
+func (*errorTargetMaker) marshal(target target) []byte {
var errorName string
switch tg := target.(type) {
- case *stack.ErrorTarget:
- errorName = stack.ErrorTargetName
- case *stack.UserChainTarget:
+ case *errorTarget:
+ errorName = ErrorTargetName
+ case *userChainTarget:
errorName = tg.Name
default:
panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target))
@@ -140,37 +220,38 @@ func (*errorTargetMaker) marshal(target stack.Target) []byte {
},
}
copy(xt.Name[:], errorName)
- copy(xt.Target.Name[:], stack.ErrorTargetName)
+ copy(xt.Target.Name[:], ErrorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) != linux.SizeOfXTErrorTarget {
nflog("buf has insufficient size for error target %d", len(buf))
return nil, syserr.ErrInvalidArgument
}
- var errorTarget linux.XTErrorTarget
+ var errTgt linux.XTErrorTarget
buf = buf[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+ binary.Unmarshal(buf, usermem.ByteOrder, &errTgt)
// Error targets are used in 2 cases:
- // * An actual error case. These rules have an error
- // named stack.ErrorTargetName. The last entry of the table
- // is usually an error case to catch any packets that
- // somehow fall through every rule.
+ // * An actual error case. These rules have an error named
+ // ErrorTargetName. The last entry of the table is usually an error
+ // case to catch any packets that somehow fall through every rule.
// * To mark the start of a user defined chain. These
// rules have an error with the name of the chain.
- switch name := errorTarget.Name.String(); name {
- case stack.ErrorTargetName:
- return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil
+ switch name := errTgt.Name.String(); name {
+ case ErrorTargetName:
+ return &errorTarget{stack.ErrorTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}, nil
default:
// User defined chain.
- return &stack.UserChainTarget{
+ return &userChainTarget{stack.UserChainTarget{
Name: name,
NetworkProtocol: filter.NetworkProtocol(),
- }, nil
+ }}, nil
}
}
@@ -178,22 +259,22 @@ type redirectTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (rm *redirectTargetMaker) id() stack.TargetID {
- return stack.TargetID{
- Name: stack.RedirectTargetName,
- NetworkProtocol: rm.NetworkProtocol,
+func (rm *redirectTargetMaker) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rm.NetworkProtocol,
}
}
-func (*redirectTargetMaker) marshal(target stack.Target) []byte {
- rt := target.(*stack.RedirectTarget)
+func (*redirectTargetMaker) marshal(target target) []byte {
+ rt := target.(*redirectTarget)
// This is a redirect target named redirect
xt := linux.XTRedirectTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTRedirectTarget,
},
}
- copy(xt.Target.Name[:], stack.RedirectTargetName)
+ copy(xt.Target.Name[:], RedirectTargetName)
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
xt.NfRange.RangeSize = 1
@@ -203,7 +284,7 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) < linux.SizeOfXTRedirectTarget {
nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf))
return nil, syserr.ErrInvalidArgument
@@ -214,15 +295,17 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- var redirectTarget linux.XTRedirectTarget
+ var rt linux.XTRedirectTarget
buf = buf[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+ binary.Unmarshal(buf, usermem.ByteOrder, &rt)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
- target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()}
+ target := redirectTarget{RedirectTarget: stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}
// RangeSize should be 1.
- nfRange := redirectTarget.NfRange
+ nfRange := rt.NfRange
if nfRange.RangeSize != 1 {
nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize)
return nil, syserr.ErrInvalidArgument
@@ -247,7 +330,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
target.Port = ntohs(nfRange.RangeIPV4.MinPort)
return &target, nil
@@ -264,15 +347,15 @@ type nfNATTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (rm *nfNATTargetMaker) id() stack.TargetID {
- return stack.TargetID{
- Name: stack.RedirectTargetName,
- NetworkProtocol: rm.NetworkProtocol,
+func (rm *nfNATTargetMaker) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rm.NetworkProtocol,
}
}
-func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
- rt := target.(*stack.RedirectTarget)
+func (*nfNATTargetMaker) marshal(target target) []byte {
+ rt := target.(*redirectTarget)
nt := nfNATTarget{
Target: linux.XTEntryTarget{
TargetSize: nfNATMarhsalledSize,
@@ -281,9 +364,9 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
},
}
- copy(nt.Target.Name[:], stack.RedirectTargetName)
- copy(nt.Range.MinAddr[:], rt.Addr)
- copy(nt.Range.MaxAddr[:], rt.Addr)
+ copy(nt.Target.Name[:], RedirectTargetName)
+ copy(nt.Range.MinAddr[:], rt.addr)
+ copy(nt.Range.MaxAddr[:], rt.addr)
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
@@ -292,7 +375,7 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, nt)
}
-func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if size := nfNATMarhsalledSize; len(buf) < size {
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
return nil, syserr.ErrInvalidArgument
@@ -324,10 +407,12 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta
return nil, syserr.ErrInvalidArgument
}
- target := stack.RedirectTarget{
- NetworkProtocol: filter.NetworkProtocol(),
- Addr: tcpip.Address(natRange.MinAddr[:]),
- Port: ntohs(natRange.MinProto),
+ target := redirectTarget{
+ RedirectTarget: stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Port: ntohs(natRange.MinProto),
+ },
+ addr: tcpip.Address(natRange.MinAddr[:]),
}
return &target, nil
@@ -335,18 +420,24 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
-func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) {
+func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return &stack.AcceptTarget{NetworkProtocol: netProto}, nil
+ return &acceptTarget{stack.AcceptTarget{
+ NetworkProtocol: netProto,
+ }}, nil
case -linux.NF_DROP - 1:
- return &stack.DropTarget{NetworkProtocol: netProto}, nil
+ return &dropTarget{stack.DropTarget{
+ NetworkProtocol: netProto,
+ }}, nil
case -linux.NF_QUEUE - 1:
nflog("unsupported iptables verdict QUEUE")
return nil, syserr.ErrInvalidArgument
case linux.NF_RETURN:
- return &stack.ReturnTarget{NetworkProtocol: netProto}, nil
+ return &returnTarget{stack.ReturnTarget{
+ NetworkProtocol: netProto,
+ }}, nil
default:
nflog("unknown iptables verdict %d", val)
return nil, syserr.ErrInvalidArgument
@@ -382,9 +473,9 @@ type JumpTarget struct {
}
// ID implements Target.ID.
-func (jt *JumpTarget) ID() stack.TargetID {
- return stack.TargetID{
- NetworkProtocol: jt.NetworkProtocol,
+func (jt *JumpTarget) id() targetID {
+ return targetID{
+ networkProtocol: jt.NetworkProtocol,
}
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 844acfede..352c51390 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
}
if filter.Protocol != header.TCPProtocolNumber {
- return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber)
}
return &TCPMatcher{
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 63201201c..c88d8268d 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
}
if filter.Protocol != header.UDPProtocolNumber {
- return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber)
}
return &UDPMatcher{
diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go
index e8930f031..f061c5d62 100644
--- a/pkg/sentry/socket/netlink/provider_vfs2.go
+++ b/pkg/sentry/socket/netlink/provider_vfs2.go
@@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol
vfsfd := &s.vfsfd
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index c84d8bd7c..f4d034c13 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -36,9 +36,9 @@ type commandKind int
const (
kindNew commandKind = 0x0
- kindDel = 0x1
- kindGet = 0x2
- kindSet = 0x3
+ kindDel commandKind = 0x1
+ kindGet commandKind = 0x2
+ kindSet commandKind = 0x3
)
func typeKind(typ uint16) commandKind {
@@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin
}
attrs = rest
+ // NOTE: A netlink message will contain multiple header attributes.
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent
+ // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the
+ // local interface address. We add the local interface address here
+ // and ignore the IFA_ADDRESS.
switch ahdr.Type {
case linux.IFA_LOCAL:
err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
@@ -439,11 +444,60 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin
} else if err != nil {
return syserr.ErrInvalidArgument
}
+ case linux.IFA_ADDRESS:
+ default:
+ return syserr.ErrNotSupported
}
}
return nil
}
+// delAddr handles RTM_DELADDR requests.
+func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifa linux.InterfaceAddrMessage
+ attrs, ok := msg.GetData(&ifa)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ // NOTE: A netlink message will contain multiple header attributes.
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent
+ // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the
+ // local interface address. We use the local interface address to
+ // remove the address and ignore the IFA_ADDRESS.
+ switch ahdr.Type {
+ case linux.IFA_LOCAL:
+ err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
+ Family: ifa.Family,
+ PrefixLen: ifa.PrefixLen,
+ Flags: ifa.Flags,
+ Addr: value,
+ })
+ if err != nil {
+ return syserr.ErrBadLocalAddress
+ }
+ case linux.IFA_ADDRESS:
+ default:
+ return syserr.ErrNotSupported
+ }
+ }
+
+ return nil
+}
+
// ProcessMessage implements netlink.Protocol.ProcessMessage.
func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
hdr := msg.Header()
@@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
return p.dumpRoutes(ctx, msg, ms)
case linux.RTM_NEWADDR:
return p.newAddr(ctx, msg, ms)
+ case linux.RTM_DELADDR:
+ return p.delAddr(ctx, msg, ms)
default:
return syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 3baad098b..057f4d294 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -120,9 +120,6 @@ type socketOpsCommon struct {
// fixed buffer but only consume this many bytes.
sendBufferSize uint32
- // passcred indicates if this socket wants SCM credentials.
- passcred bool
-
// filter indicates that this socket has a BPF filter "installed".
//
// TODO(gvisor.dev/issue/1119): We don't actually support filtering,
@@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
// Passcred implements transport.Credentialer.Passcred.
func (s *socketOpsCommon) Passcred() bool {
- s.mu.Lock()
- passcred := s.passcred
- s.mu.Unlock()
- return passcred
+ return s.ep.SocketOptions().GetPassCred()
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
@@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
}
passcred := usermem.ByteOrder.Uint32(opt)
- s.mu.Lock()
- s.passcred = passcred != 0
- s.mu.Unlock()
+ s.ep.SocketOptions().SetPassCred(passcred != 0)
return nil
case linux.SO_ATTACH_FILTER:
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index c83b23242..461d524e5 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -37,6 +37,8 @@ import (
// to/from the kernel.
//
// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 211f07947..23d5cab9c 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -84,69 +84,95 @@ var Metrics = tcpip.Stats{
MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
ICMP: tcpip.ICMPStats{
- V4PacketsSent: tcpip.ICMPv4SentPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ V4: tcpip.ICMPv4Stats{
+ PacketsSent: tcpip.ICMPv4SentPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
},
- V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ V6: tcpip.ICMPv6Stats{
+ PacketsSent: tcpip.ICMPv6SentPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- V6PacketsSent: tcpip.ICMPv6SentPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ IGMP: tcpip.IGMPStats{
+ PacketsSent: tcpip.IGMPSentPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."),
},
- V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ PacketsReceived: tcpip.IGMPReceivedPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."),
+ ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."),
+ Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."),
},
},
IP: tcpip.IPStats{
@@ -209,18 +235,6 @@ const sizeOfInt32 int = 4
var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
-// ntohs converts a 16-bit number from network byte order to host byte order. It
-// assumes that the host is little endian.
-func ntohs(v uint16) uint16 {
- return v<<8 | v>>8
-}
-
-// htons converts a 16-bit number from host byte order to network byte order. It
-// assumes that the host is little endian.
-func htons(v uint16) uint16 {
- return ntohs(v)
-}
-
// commonEndpoint represents the intersection of a tcpip.Endpoint and a
// transport.Endpoint.
type commonEndpoint interface {
@@ -240,10 +254,6 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
- // transport.Endpoint.SetSockOptBool.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -252,16 +262,21 @@ type commonEndpoint interface {
// transport.Endpoint.GetSockOpt.
GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
- // transport.Endpoint.GetSockOpt.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
- // LastError implements tcpip.Endpoint.LastError.
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
+
+ // LastError implements tcpip.Endpoint.LastError and
+ // transport.Endpoint.LastError.
LastError() *tcpip.Error
+
+ // SocketOptions implements tcpip.Endpoint.SocketOptions and
+ // transport.Endpoint.SocketOptions.
+ SocketOptions() *tcpip.SocketOptions
}
// LINT.IfChange
@@ -305,7 +320,7 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
+ readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
@@ -329,9 +344,7 @@ type socketOpsCommon struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
dirent := socket.NewDirent(t, netstackDevice)
@@ -360,88 +373,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
-// AF_INET6, and AF_PACKET addresses.
-//
-// AddressAndFamily returns an address and its family.
-func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
- // Make sure we have at least 2 bytes for the address family.
- if len(addr) < 2 {
- return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
- }
-
- // Get the rest of the fields based on the address family.
- switch family := usermem.ByteOrder.Uint16(addr); family {
- case linux.AF_UNIX:
- path := addr[2:]
- if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- // Drop the terminating NUL (if one exists) and everything after
- // it for filesystem (non-abstract) addresses.
- if len(path) > 0 && path[0] != 0 {
- if n := bytes.IndexByte(path[1:], 0); n >= 0 {
- path = path[:n+1]
- }
- }
- return tcpip.FullAddress{
- Addr: tcpip.Address(path),
- }, family, nil
-
- case linux.AF_INET:
- var a linux.SockAddrInet
- if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- return out, family, nil
-
- case linux.AF_INET6:
- var a linux.SockAddrInet6
- if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- if isLinkLocal(out.Addr) {
- out.NIC = tcpip.NICID(a.Scope_id)
- }
- return out, family, nil
-
- case linux.AF_PACKET:
- var a linux.SockAddrLink
- if len(addr) < sockAddrLinkSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
- if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
-
- // TODO(gvisor.dev/issue/173): Return protocol too.
- return tcpip.FullAddress{
- NIC: tcpip.NICID(a.InterfaceIndex),
- Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
- }, family, nil
-
- case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, family, nil
-
- default:
- return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
- }
-}
-
func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
@@ -477,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = v
- s.readCM = cms
+ s.readCM = socket.NewIPControlMessages(s.family, cms)
atomic.StoreUint32(&s.readViewHasData, 1)
return nil
@@ -497,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
return
}
- var v tcpip.LingerOption
- if err := s.Endpoint.GetSockOpt(&v); err != nil {
- return
- }
-
+ v := s.Endpoint.SocketOptions().GetLinger()
// The case for zero timeout is handled in tcp endpoint close function.
// Close is blocked until either:
// 1. The endpoint state is not in any of the states: FIN-WAIT1,
@@ -718,11 +645,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
return nil
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
- v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return syserr.TranslateNetstackError(err)
- }
- if !v {
+ if !s.Endpoint.SocketOptions().GetV6Only() {
return nil
}
}
@@ -746,7 +669,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -827,7 +750,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
} else {
var err *syserr.Error
- addr, family, err = AddressAndFamily(sockaddr)
+ addr, family, err = socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -918,7 +841,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -1002,7 +925,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
case linux.SOL_TCP:
- return getSockOptTCP(t, ep, name, outLen)
+ return getSockOptTCP(t, s, ep, name, outLen)
case linux.SOL_IPV6:
return getSockOptIPv6(t, s, ep, name, outPtr, outLen)
@@ -1038,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.LastError()
+ err := ep.SocketOptions().GetLastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1065,13 +988,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.PasscredOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred()))
+ return &v, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1112,25 +1030,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress()))
+ return &v, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort()))
+ return &v, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1163,37 +1072,24 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.BroadcastOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetBroadcast()))
+ return &v, nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetKeepAlive()))
+ return &v, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.LingerOption
var linger linux.Linger
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetLinger()
if v.Enabled {
linger.OnOff = 1
@@ -1224,22 +1120,26 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.OutOfBandInlineOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline()))
+ return &v, nil
+
+ case linux.SO_NO_CHECK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(v)
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum()))
+ return &v, nil
- case linux.SO_NO_CHECK:
+ case linux.SO_ACCEPTCONN:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ // This option is only viable for TCP endpoints.
+ var v bool
+ if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) {
+ v = tcp.EndpointState(ep.State()) == tcp.StateListen
}
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
@@ -1251,46 +1151,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(!v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption()))
+ return &v, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.CorkOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption()))
+ return &v, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck()))
+ return &v, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1464,19 +1354,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, _ := s.Type()
+ if family != linux.AF_INET6 {
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only()))
+ return &v, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1508,13 +1403,16 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
+ return &v, nil
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
@@ -1526,7 +1424,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
@@ -1535,7 +1433,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1555,7 +1453,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1575,7 +1473,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1597,6 +1495,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
// getSockOptIP implements GetSockOpt when level is SOL_IP.
func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1639,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
return &a.(*linux.SockAddrInet).Addr, nil
@@ -1648,13 +1551,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop()))
+ return &v, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
@@ -1678,26 +1576,32 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
+ return &v, nil
+
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo()))
+ return &v, nil
- case linux.IP_PKTINFO:
+ case linux.IP_HDRINCL:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
+ return &v, nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
@@ -1709,7 +1613,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet), nil
case linux.IPT_SO_GET_INFO:
@@ -1816,7 +1720,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int
return setSockOptSocket(t, s, ep, name, optVal)
case linux.SOL_TCP:
- return setSockOptTCP(t, ep, name, optVal)
+ return setSockOptTCP(t, s, ep, name, optVal)
case linux.SOL_IPV6:
return setSockOptIPv6(t, s, ep, name, optVal)
@@ -1866,7 +1770,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
+ ep.SocketOptions().SetReuseAddress(v != 0)
+ return nil
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1874,7 +1779,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
+ ep.SocketOptions().SetReusePort(v != 0)
+ return nil
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
@@ -1904,7 +1810,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0))
+ ep.SocketOptions().SetBroadcast(v != 0)
+ return nil
case linux.SO_PASSCRED:
if len(optVal) < sizeOfInt32 {
@@ -1912,7 +1819,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
+ ep.SocketOptions().SetPassCred(v != 0)
+ return nil
case linux.SO_KEEPALIVE:
if len(optVal) < sizeOfInt32 {
@@ -1920,7 +1828,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
+ ep.SocketOptions().SetKeepAlive(v != 0)
+ return nil
case linux.SO_SNDTIMEO:
if len(optVal) < linux.SizeOfTimeval {
@@ -1959,8 +1868,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- opt := tcpip.OutOfBandInlineOption(v)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
+ ep.SocketOptions().SetOutOfBandInline(v != 0)
+ return nil
case linux.SO_NO_CHECK:
if len(optVal) < sizeOfInt32 {
@@ -1968,7 +1877,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+ ep.SocketOptions().SetNoChecksum(v != 0)
+ return nil
case linux.SO_LINGER:
if len(optVal) < linux.SizeOfLinger {
@@ -1982,10 +1892,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return syserr.TranslateNetstackError(
- ep.SetSockOpt(&tcpip.LingerOption{
- Enabled: v.OnOff != 0,
- Timeout: time.Second * time.Duration(v.Linger)}))
+ ep.SocketOptions().SetLinger(tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger),
+ })
+ return nil
case linux.SO_DETACH_FILTER:
// optval is ignored.
@@ -2000,7 +1911,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
-func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if len(optVal) < sizeOfInt32 {
@@ -2008,7 +1924,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
+ ep.SocketOptions().SetDelayOption(v == 0)
+ return nil
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -2016,7 +1933,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
+ ep.SocketOptions().SetCorkOption(v != 0)
+ return nil
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -2024,7 +1942,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
+ ep.SocketOptions().SetQuickAck(v != 0)
+ return nil
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -2136,14 +2055,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, skProto := s.Type()
+ if family != linux.AF_INET6 {
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
+ if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ }
+
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
+ ep.SocketOptions().SetV6Only(v != 0)
+ return nil
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
@@ -2163,6 +2099,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2182,7 +2127,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+ ep.SocketOptions().SetReceiveTClass(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2190,7 +2136,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return syserr.ErrProtocolNotAvailable
}
@@ -2265,6 +2211,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
// setSockOptIP implements SetSockOpt when level is SOL_IP.
func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_MULTICAST_TTL:
v, err := parseIntOrChar(optVal)
@@ -2317,7 +2268,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{
NIC: tcpip.NICID(req.InterfaceIndex),
- InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
+ InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]),
}))
case linux.IP_MULTICAST_LOOP:
@@ -2326,7 +2277,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
+ ep.SocketOptions().SetMulticastLoop(v != 0)
+ return nil
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -2362,7 +2314,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ ep.SocketOptions().SetReceiveTOS(v != 0)
+ return nil
case linux.IP_PKTINFO:
if len(optVal) == 0 {
@@ -2372,7 +2325,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ ep.SocketOptions().SetReceivePacketInfo(v != 0)
+ return nil
case linux.IP_HDRINCL:
if len(optVal) == 0 {
@@ -2382,7 +2336,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+ ep.SocketOptions().SetHeaderIncluded(v != 0)
+ return nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
@@ -2422,7 +2389,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2500,7 +2466,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2524,7 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
switch name {
case linux.IP_TOS,
linux.IP_TTL,
- linux.IP_HDRINCL,
linux.IP_OPTIONS,
linux.IP_ROUTER_ALERT,
linux.IP_RECVOPTS,
@@ -2571,72 +2535,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
}
}
-// isLinkLocal determines if the given IPv6 address is link-local. This is the
-// case when it has the fe80::/10 prefix. This check is used to determine when
-// the NICID is relevant for a given IPv6 address.
-func isLinkLocal(addr tcpip.Address) bool {
- return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
-}
-
-// ConvertAddress converts the given address to a native format.
-func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
- switch family {
- case linux.AF_UNIX:
- var out linux.SockAddrUnix
- out.Family = linux.AF_UNIX
- l := len([]byte(addr.Addr))
- for i := 0; i < l; i++ {
- out.Path[i] = int8(addr.Addr[i])
- }
-
- // Linux returns the used length of the address struct (including the
- // null terminator) for filesystem paths. The Family field is 2 bytes.
- // It is sometimes allowed to exclude the null terminator if the
- // address length is the max. Abstract and empty paths always return
- // the full exact length.
- if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
- return &out, uint32(2 + l)
- }
- return &out, uint32(3 + l)
-
- case linux.AF_INET:
- var out linux.SockAddrInet
- copy(out.Addr[:], addr.Addr)
- out.Family = linux.AF_INET
- out.Port = htons(addr.Port)
- return &out, uint32(sockAddrInetSize)
-
- case linux.AF_INET6:
- var out linux.SockAddrInet6
- if len(addr.Addr) == header.IPv4AddressSize {
- // Copy address in v4-mapped format.
- copy(out.Addr[12:], addr.Addr)
- out.Addr[10] = 0xff
- out.Addr[11] = 0xff
- } else {
- copy(out.Addr[:], addr.Addr)
- }
- out.Family = linux.AF_INET6
- out.Port = htons(addr.Port)
- if isLinkLocal(addr.Addr) {
- out.Scope_id = uint32(addr.NIC)
- }
- return &out, uint32(sockAddrInet6Size)
-
- case linux.AF_PACKET:
- // TODO(gvisor.dev/issue/173): Return protocol too.
- var out linux.SockAddrLink
- out.Family = linux.AF_PACKET
- out.InterfaceIndex = int32(addr.NIC)
- out.HardwareAddrLen = header.EthernetAddressSize
- copy(out.HardwareAddr[:], addr.Addr)
- return &out, uint32(sockAddrLinkSize)
-
- default:
- return nil, 0
- }
-}
-
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
@@ -2645,7 +2543,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
@@ -2657,7 +2555,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
@@ -2675,7 +2573,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
// Always do at least one fetchReadView, even if the number of bytes to
// read is 0.
err = s.fetchReadView()
- if err != nil {
+ if err != nil || len(s.readView) == 0 {
break
}
if dst.NumBytes() == 0 {
@@ -2698,15 +2596,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
}
copied += n
s.readView.TrimFront(n)
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
- }
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
break
}
+ // If we are done reading requested data then stop.
+ if dst.NumBytes() == 0 {
+ break
+ }
+ }
+
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
}
// If we managed to copy something, we must deliver it.
@@ -2801,10 +2704,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
- addr, addrLen = ConvertAddress(s.family, s.sender)
+ addr, addrLen = socket.ConvertAddress(s.family, s.sender)
switch v := addr.(type) {
case *linux.SockAddrLink:
- v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
@@ -2822,7 +2725,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// We need to peek beyond the first message.
dst = dst.DropFirst(n)
num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, _, err := s.Endpoint.Peek(dsts)
+ n, err := s.Endpoint.Peek(dsts)
// TODO(b/78348848): Handle peek timestamp.
if err != nil {
return int64(n), syserr.TranslateNetstackError(err).ToError()
@@ -2866,15 +2769,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
- IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ IP: socket.IPControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
},
}
}
@@ -2969,7 +2873,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, family, err := AddressAndFamily(to)
+ addrBuf, family, err := socket.AddressAndFamily(to)
if err != nil {
return 0, err
}
@@ -3388,6 +3292,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
return rv
}
+func isTCPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP)
+}
+
+func isUDPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP)
+}
+
+func isICMPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6)
+}
+
// State implements socket.Socket.State. State translates the internal state
// returned by netstack to values defined by Linux.
func (s *socketOpsCommon) State() uint32 {
@@ -3397,7 +3313,7 @@ func (s *socketOpsCommon) State() uint32 {
}
switch {
- case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
+ case isTCPSocket(s.skType, s.protocol):
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -3426,7 +3342,7 @@ func (s *socketOpsCommon) State() uint32 {
// Internal or unknown state.
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ case isUDPSocket(s.skType, s.protocol):
// UDP socket.
switch udp.EndpointState(s.Endpoint.State()) {
case udp.StateInitial, udp.StateBound, udp.StateClosed:
@@ -3436,7 +3352,7 @@ func (s *socketOpsCommon) State() uint32 {
default:
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ case isICMPSocket(s.skType, s.protocol):
// TODO(b/112063468): Export states for ICMP sockets.
case s.skType == linux.SOCK_RAW:
// TODO(b/112063468): Export states for raw sockets.
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 4c6791fff..b756bfca0 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -35,6 +35,8 @@ import (
// SocketVFS2 encapsulates all the state needed to represent a network stack
// endpoint in the kernel context.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -49,13 +51,11 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// NewVFS2 creates a new endpoint socket.
func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
s := &SocketVFS2{
@@ -189,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addrLen uint32
if peerAddr != nil {
// Get address of the peer and write it to peer slice.
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index ead3b2b79..c847ff1c7 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
index 2a01143f6..0af805246 100644
--- a/pkg/sentry/socket/netstack/provider_vfs2.go
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 1028d2a6e..cc0fadeb5 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
return nicAddrs
}
-// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+// convertAddr converts an InterfaceAddr to a ProtocolAddress.
+func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) {
var (
- protocol tcpip.NetworkProtocolNumber
- address tcpip.Address
+ protocol tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ protocolAddress tcpip.ProtocolAddress
)
switch addr.Family {
case linux.AF_INET:
- if len(addr.Addr) < header.IPv4AddressSize {
- return syserror.EINVAL
+ if len(addr.Addr) != header.IPv4AddressSize {
+ return protocolAddress, syserror.EINVAL
}
if addr.PrefixLen > header.IPv4AddressSize*8 {
- return syserror.EINVAL
+ return protocolAddress, syserror.EINVAL
}
protocol = ipv4.ProtocolNumber
- address = tcpip.Address(addr.Addr[:header.IPv4AddressSize])
-
+ address = tcpip.Address(addr.Addr)
case linux.AF_INET6:
- if len(addr.Addr) < header.IPv6AddressSize {
- return syserror.EINVAL
+ if len(addr.Addr) != header.IPv6AddressSize {
+ return protocolAddress, syserror.EINVAL
}
if addr.PrefixLen > header.IPv6AddressSize*8 {
- return syserror.EINVAL
+ return protocolAddress, syserror.EINVAL
}
protocol = ipv6.ProtocolNumber
- address = tcpip.Address(addr.Addr[:header.IPv6AddressSize])
-
+ address = tcpip.Address(addr.Addr)
default:
- return syserror.ENOTSUP
+ return protocolAddress, syserror.ENOTSUP
}
- protocolAddress := tcpip.ProtocolAddress{
+ protocolAddress = tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: int(addr.PrefixLen),
},
}
+ return protocolAddress, nil
+}
+
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ protocolAddress, err := convertAddr(addr)
+ if err != nil {
+ return err
+ }
// Attach address to interface.
- if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ nicID := tcpip.NICID(idx)
+ if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ return syserr.TranslateNetstackError(err).ToError()
+ }
+
+ // Add route for local network if it doesn't exist already.
+ localRoute := tcpip.Route{
+ Destination: protocolAddress.AddressWithPrefix.Subnet(),
+ Gateway: "", // No gateway for local network.
+ NIC: nicID,
+ }
+
+ for _, rt := range s.Stack.GetRouteTable() {
+ if rt.Equal(localRoute) {
+ return nil
+ }
+ }
+
+ // Local route does not exist yet. Add it.
+ s.Stack.AddRoute(localRoute)
+
+ return nil
+}
+
+// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
+func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ protocolAddress, err := convertAddr(addr)
+ if err != nil {
+ return err
+ }
+
+ // Remove addresses matching the address and prefix.
+ nicID := tcpip.NICID(idx)
+ if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil {
return syserr.TranslateNetstackError(err).ToError()
}
- // Add route for local network.
- s.Stack.AddRoute(tcpip.Route{
+ // Remove the corresponding local network route if it exists.
+ localRoute := tcpip.Route{
Destination: protocolAddress.AddressWithPrefix.Subnet(),
Gateway: "", // No gateway for local network.
- NIC: tcpip.NICID(idx),
+ NIC: nicID,
+ }
+ s.Stack.RemoveRoutes(func(rt tcpip.Route) bool {
+ return rt.Equal(localRoute)
})
+
return nil
}
@@ -279,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
- in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
- out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
0, // Icmp/InMsgs.
- Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors.
0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
@@ -298,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
0, // Icmp/OutMsgs.
- Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
- out.DstUnreachable.Value(), // OutDestUnreachs.
- out.TimeExceeded.Value(), // OutTimeExcds.
- out.ParamProblem.Value(), // OutParmProbs.
- out.SrcQuench.Value(), // OutSrcQuenchs.
- out.Redirect.Value(), // OutRedirects.
- out.Echo.Value(), // OutEchos.
- out.EchoReply.Value(), // OutEchoReps.
- out.Timestamp.Value(), // OutTimestamps.
- out.TimestampReply.Value(), // OutTimestampReps.
- out.InfoRequest.Value(), // OutAddrMasks.
- out.InfoReply.Value(), // OutAddrMaskReps.
+ Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
}
case *inet.StatSNMPTCP:
tcp := Metrics.TCP
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fd31479e5..bcc426e33 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -18,6 +18,7 @@
package socket
import (
+ "bytes"
"fmt"
"sync/atomic"
"syscall"
@@ -35,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -42,7 +44,79 @@ import (
// control messages.
type ControlMessages struct {
Unix transport.ControlMessages
- IP tcpip.ControlMessages
+ IP IPControlMessages
+}
+
+// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format.
+func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ return p
+}
+
+// NewIPControlMessages converts the tcpip ControlMessgaes (which does not
+// have Linux specific format) to Linux format.
+func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages {
+ var orgDstAddr linux.SockAddr
+ if cmgs.HasOriginalDstAddress {
+ orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
+ }
+ return IPControlMessages{
+ HasTimestamp: cmgs.HasTimestamp,
+ Timestamp: cmgs.Timestamp,
+ HasInq: cmgs.HasInq,
+ Inq: cmgs.Inq,
+ HasTOS: cmgs.HasTOS,
+ TOS: cmgs.TOS,
+ HasTClass: cmgs.HasTClass,
+ TClass: cmgs.TClass,
+ HasIPPacketInfo: cmgs.HasIPPacketInfo,
+ PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ OriginalDstAddress: orgDstAddr,
+ }
+}
+
+// IPControlMessages contains socket control messages for IP sockets.
+// This can contain Linux specific structures unlike tcpip.ControlMessages.
+//
+// +stateify savable
+type IPControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
+
+ // HasTOS indicates whether Tos is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo linux.ControlMessageIPPacketInfo
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress linux.SockAddr
}
// Release releases Unix domain socket credentials and rights.
@@ -460,3 +534,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
panic(fmt.Sprintf("Unsupported socket family %v", family))
}
}
+
+var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes()
+var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes()
+var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes()
+
+// Ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian.
+func Ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// Htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian.
+func Htons(v uint16) uint16 {
+ return Ntohs(v)
+}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ l := len([]byte(addr.Addr))
+ for i := 0; i < l; i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // It is sometimes allowed to exclude the null terminator if the
+ // address length is the max. Abstract and empty paths always return
+ // the full exact length.
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
+ return &out, uint32(2 + l)
+ }
+ return &out, uint32(3 + l)
+
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = Htons(addr.Port)
+ return &out, uint32(sockAddrInetSize)
+
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = Htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
+ default:
+ return nil, 0
+ }
+}
+
+// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the
+// netstack representation taking any addresses into account.
+func BytesToIPAddress(addr []byte) tcpip.Address {
+ if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
+ return ""
+ }
+ return tcpip.Address(addr)
+}
+
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
+//
+// AddressAndFamily returns an address and its family.
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family := usermem.ByteOrder.Uint16(addr); family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ if len(path) > linux.UnixPathMax {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ // Drop the terminating NUL (if one exists) and everything after
+ // it for filesystem (non-abstract) addresses.
+ if len(path) > 0 && path[0] != 0 {
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, family, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ return out, family, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ return out, family, nil
+
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, family, nil
+
+ default:
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
+ }
+}
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cc7408698..cce0acc33 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "socket_refs.go",
package = "unix",
prefix = "socketOperations",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketOperations",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "socket_vfs2_refs.go",
package = "unix",
prefix = "socketVFS2",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketVFS2",
},
@@ -43,6 +43,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 26c3a51b9..3ebbd28b0 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -20,7 +20,7 @@ go_template_instance(
out = "queue_refs.go",
package = "transport",
prefix = "queue",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "queue",
},
@@ -44,6 +44,7 @@ go_library(
"//pkg/ilist",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index aa4f3c04d..9f7aca305 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -118,33 +118,29 @@ var (
// NewConnectioned creates a new unbound connectionedEndpoint.
func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
- return &connectionedEndpoint{
+ return newConnectioned(ctx, stype, uid)
+}
+
+func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint {
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
- a := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
- b := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
+ a := newConnectioned(ctx, stype, uid)
+ b := newConnectioned(ctx, stype, uid)
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
- q1.EnableLeakCheck()
+ q1.InitRefs()
q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
- q2.EnableLeakCheck()
+ q2.InitRefs()
if stype == linux.SOCK_STREAM {
a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
@@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
// NewExternal creates a new externally backed Endpoint. It behaves like a
// socketpair.
func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
- return &connectionedEndpoint{
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// ID implements ConnectingEndpoint.ID.
@@ -298,16 +296,17 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
+ ne.ops.InitHandler(ne)
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
- readQueue.EnableLeakCheck()
+ readQueue.InitRefs()
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
- writeQueue.EnableLeakCheck()
+ writeQueue.InitRefs()
if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index f8aacca13..0813ad87d 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -42,8 +42,9 @@ var (
func NewConnectionless(ctx context.Context) Endpoint {
ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
- q.EnableLeakCheck()
+ q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
+ ep.ops.InitHandler(ep)
return ep
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index d6fc03520..099a56281 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -16,8 +16,6 @@
package transport
import (
- "sync/atomic"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
@@ -32,6 +30,8 @@ import (
const initialLimit = 16 * 1024
// A RightsControlMessage is a control message containing FDs.
+//
+// +stateify savable
type RightsControlMessage interface {
// Clone returns a copy of the RightsControlMessage.
Clone() RightsControlMessage
@@ -178,10 +178,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool sets a socket option for simple cases when a value has
- // the int type.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -189,10 +185,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool gets a socket option for simple cases when a return
- // value has the int type.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
@@ -201,8 +193,12 @@ type Endpoint interface {
// procfs.
State() uint32
- // LastError implements tcpip.Endpoint.LastError.
+ // LastError clears and returns the last error reported by the endpoint.
LastError() *tcpip.Error
+
+ // SocketOptions returns the structure which contains all the socket
+ // level options.
+ SocketOptions() *tcpip.SocketOptions
}
// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
@@ -336,7 +332,7 @@ type Receiver interface {
RecvMaxQueueSize() int64
// Release releases any resources owned by the Receiver. It should be
- // called before droping all references to a Receiver.
+ // called before dropping all references to a Receiver.
Release(ctx context.Context)
}
@@ -487,7 +483,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
c := q.control.Clone()
// Don't consume data since we are peeking.
- copied, data, _ = vecCopy(data, q.buffer)
+ copied, _, _ = vecCopy(data, q.buffer)
return copied, copied, c, false, q.addr, notify, nil
}
@@ -572,6 +568,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
return copied, copied, c, cmTruncated, q.addr, notify, nil
}
+// Release implements Receiver.Release.
+func (q *streamQueueReceiver) Release(ctx context.Context) {
+ q.queueReceiver.Release(ctx)
+ q.control.Release(ctx)
+}
+
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
type ConnectedEndpoint interface {
// Passcred implements Endpoint.Passcred.
@@ -619,7 +621,7 @@ type ConnectedEndpoint interface {
SendMaxQueueSize() int64
// Release releases any resources owned by the ConnectedEndpoint. It should
- // be called before droping all references to a ConnectedEndpoint.
+ // be called before dropping all references to a ConnectedEndpoint.
Release(ctx context.Context)
// CloseUnread sets the fact that this end is closed with unread data to
@@ -728,10 +730,7 @@ func (e *connectedEndpoint) CloseUnread() {
// +stateify savable
type baseEndpoint struct {
*waiter.Queue
-
- // passcred specifies whether SCM_CREDENTIALS socket control messages are
- // enabled on this endpoint. Must be accessed atomically.
- passcred int32
+ tcpip.DefaultSocketOptionsHandler
// Mutex protects the below fields.
sync.Mutex `state:"nosave"`
@@ -747,8 +746,8 @@ type baseEndpoint struct {
// or may be used if the endpoint is connected.
path string
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// EventRegister implements waiter.Waitable.EventRegister.
@@ -773,7 +772,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
// Passcred implements Credentialer.Passcred.
func (e *baseEndpoint) Passcred() bool {
- return atomic.LoadInt32(&e.passcred) != 0
+ return e.SocketOptions().GetPassCred()
}
// ConnectedPasscred implements Credentialer.ConnectedPasscred.
@@ -783,14 +782,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool {
return e.connected != nil && e.connected.Passcred()
}
-func (e *baseEndpoint) setPasscred(pc bool) {
- if pc {
- atomic.StoreInt32(&e.passcred, 1)
- } else {
- atomic.StoreInt32(&e.passcred, 0)
- }
-}
-
// Connected implements ConnectingEndpoint.Connected.
func (e *baseEndpoint) Connected() bool {
return e.receiver != nil && e.connected != nil
@@ -846,24 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option.
func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- e.linger = *v
- e.Unlock()
- }
- return nil
-}
-
-func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.BroadcastOption:
- case tcpip.PasscredOption:
- e.setPasscred(v)
- case tcpip.ReuseAddressOption:
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- }
return nil
}
@@ -877,20 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption:
- return false, nil
-
- case tcpip.PasscredOption:
- return e.Passcred(), nil
-
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -954,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- *o = e.linger
- e.Unlock()
- return nil
-
- default:
- log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
- }
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
}
// LastError implements Endpoint.LastError.
@@ -972,6 +922,11 @@ func (*baseEndpoint) LastError() *tcpip.Error {
return nil
}
+// SocketOptions implements Endpoint.SocketOptions.
+func (e *baseEndpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
+
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index a4a76d0a3..c59297c80 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -80,8 +80,7 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
stype: stype,
},
}
- s.EnableLeakCheck()
-
+ s.InitRefs()
return fs.NewFile(ctx, d, flags, &s)
}
@@ -137,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, family, err := netstack.AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
if err == syserr.ErrAddressFamilyNotSupported {
err = syserr.ErrInvalidArgument
@@ -170,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -182,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -256,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -648,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -683,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 678355fb9..27f705bb2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// returns a corresponding file description.
func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
@@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
stype: stype,
},
}
+ sock.InitRefs()
sock.LockFD.Init(locks)
vfsfd := &sock.vfsfd
if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
@@ -171,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{