summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/socket.go13
-rw-r--r--pkg/sentry/socket/control/control.go43
-rw-r--r--pkg/sentry/socket/hostinet/socket.go11
-rwxr-xr-xpkg/sentry/socket/netstack/netstack.go37
-rw-r--r--pkg/tcpip/tcpip.go25
-rwxr-xr-xpkg/tcpip/tcpip_state_autogen.go20
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go26
-rwxr-xr-xpkg/tcpip/transport/udp/udp_state_autogen.go4
8 files changed, 173 insertions, 6 deletions
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 766ee4014..4a14ef691 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -411,6 +411,15 @@ type ControlMessageCredentials struct {
GID uint32
}
+// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message.
+//
+// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+type ControlMessageIPPacketInfo struct {
+ NIC int32
+ LocalAddr InetAddr
+ DestinationAddr InetAddr
+}
+
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
@@ -431,6 +440,10 @@ const SizeOfControlMessageTOS = 1
// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message.
const SizeOfControlMessageTClass = 4
+// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO
+// control message.
+const SizeOfControlMessageIPPacketInfo = 12
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 6145a7fc3..4667373d2 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -338,6 +339,22 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
)
}
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -362,6 +379,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackTClass(t, cmsgs.IP.TClass, buf)
}
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
return buf
}
@@ -394,6 +415,16 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
return space
}
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
@@ -468,6 +499,18 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
i += binary.AlignUp(length, width)
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index de76388ac..22f78d2e2 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -289,7 +289,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
@@ -336,6 +336,8 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
switch name {
case linux.IP_TOS, linux.IP_RECVTOS:
optlen = sizeofInt32
+ case linux.IP_PKTINFO:
+ optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
@@ -473,7 +475,14 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
case syscall.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+
+ case syscall.IP_PKTINFO:
+ controlMessages.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
}
+
case syscall.SOL_IPV6:
switch unixCmsg.Header.Type {
case syscall.IPV6_TCLASS:
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index ed2fbcceb..9757fbfba 100755
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1414,6 +1414,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
return o, nil
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1762,6 +1777,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
+ // TODO(b/148887420): Add support for IPV6_PKTINFO.
linux.IPV6_PKTINFO,
linux.IPV6_ROUTER_ALERT,
linux.IPV6_XFRM_POLICY,
@@ -1949,6 +1965,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ case linux.IP_PKTINFO:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1964,7 +1990,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_PKTINFO,
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
@@ -2395,10 +2420,12 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
func (s *SocketOperations) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
},
}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0e944712f..9ca39ce40 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -328,6 +328,12 @@ type ControlMessages struct {
// Tclass is the IPv6 traffic class of the associated packet.
TClass int32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo IPPacketInfo
}
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
@@ -503,6 +509,11 @@ const (
// V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
V6OnlyOption
+
+ // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify
+ // if more inforamtion is provided with incoming packets such
+ // as interface index and address.
+ ReceiveIPPacketInfoOption
)
// SockOptInt represents socket options which values have the int type.
@@ -685,6 +696,20 @@ type IPv4TOSOption uint8
// for all subsequent outgoing IPv6 packets from the endpoint.
type IPv6TrafficClassOption uint8
+// IPPacketInfo is the message struture for IP_PKTINFO.
+//
+// +stateify savable
+type IPPacketInfo struct {
+ // NIC is the ID of the NIC to be used.
+ NIC NICID
+
+ // LocalAddr is the local address.
+ LocalAddr Address
+
+ // DestinationAddr is the destination address.
+ DestinationAddr Address
+}
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
index 688c1d777..272b21d87 100755
--- a/pkg/tcpip/tcpip_state_autogen.go
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -57,6 +57,8 @@ func (x *ControlMessages) save(m state.Map) {
m.Save("TOS", &x.TOS)
m.Save("HasTClass", &x.HasTClass)
m.Save("TClass", &x.TClass)
+ m.Save("HasIPPacketInfo", &x.HasIPPacketInfo)
+ m.Save("PacketInfo", &x.PacketInfo)
}
func (x *ControlMessages) afterLoad() {}
@@ -69,10 +71,28 @@ func (x *ControlMessages) load(m state.Map) {
m.Load("TOS", &x.TOS)
m.Load("HasTClass", &x.HasTClass)
m.Load("TClass", &x.TClass)
+ m.Load("HasIPPacketInfo", &x.HasIPPacketInfo)
+ m.Load("PacketInfo", &x.PacketInfo)
+}
+
+func (x *IPPacketInfo) beforeSave() {}
+func (x *IPPacketInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("NIC", &x.NIC)
+ m.Save("LocalAddr", &x.LocalAddr)
+ m.Save("DestinationAddr", &x.DestinationAddr)
+}
+
+func (x *IPPacketInfo) afterLoad() {}
+func (x *IPPacketInfo) load(m state.Map) {
+ m.Load("NIC", &x.NIC)
+ m.Load("LocalAddr", &x.LocalAddr)
+ m.Load("DestinationAddr", &x.DestinationAddr)
}
func init() {
state.Register("pkg/tcpip.PacketBuffer", (*PacketBuffer)(nil), state.Fns{Save: (*PacketBuffer).save, Load: (*PacketBuffer).load})
state.Register("pkg/tcpip.FullAddress", (*FullAddress)(nil), state.Fns{Save: (*FullAddress).save, Load: (*FullAddress).load})
state.Register("pkg/tcpip.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load})
+ state.Register("pkg/tcpip.IPPacketInfo", (*IPPacketInfo)(nil), state.Fns{Save: (*IPPacketInfo).save, Load: (*IPPacketInfo).load})
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index c9cbed8f4..3fe91cac2 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -29,6 +29,7 @@ import (
type udpPacket struct {
udpPacketEntry
senderAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
tos uint8
@@ -118,6 +119,9 @@ type endpoint struct {
// as ancillary data to ControlMessages on Read.
receiveTOS bool
+ // receiveIPPacketInfo determines if the packet info is returned by Read.
+ receiveIPPacketInfo bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -254,11 +258,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
}
e.mu.RLock()
receiveTOS := e.receiveTOS
+ receiveIPPacketInfo := e.receiveIPPacketInfo
e.mu.RUnlock()
if receiveTOS {
cm.HasTOS = true
cm.TOS = p.tos
}
+
+ if receiveIPPacketInfo {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
return p.data.ToView(), cm, nil
}
@@ -495,6 +505,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
}
e.v6only = v
+ return nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.Lock()
+ e.receiveIPPacketInfo = v
+ e.mu.Unlock()
+ return nil
}
return nil
@@ -703,6 +720,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.RLock()
+ v := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ return v, nil
}
return false, tcpip.ErrUnknownProtocolOption
@@ -1247,6 +1270,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
switch r.NetProto {
case header.IPv4ProtocolNumber:
packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.RemoteAddress
+ packet.packetInfo.NIC = r.NICID()
}
packet.timestamp = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
index 9f78f44f1..092b2fdca 100755
--- a/pkg/tcpip/transport/udp/udp_state_autogen.go
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -14,6 +14,7 @@ func (x *udpPacket) save(m state.Map) {
m.SaveValue("data", data)
m.Save("udpPacketEntry", &x.udpPacketEntry)
m.Save("senderAddress", &x.senderAddress)
+ m.Save("packetInfo", &x.packetInfo)
m.Save("timestamp", &x.timestamp)
m.Save("tos", &x.tos)
}
@@ -22,6 +23,7 @@ func (x *udpPacket) afterLoad() {}
func (x *udpPacket) load(m state.Map) {
m.Load("udpPacketEntry", &x.udpPacketEntry)
m.Load("senderAddress", &x.senderAddress)
+ m.Load("packetInfo", &x.packetInfo)
m.Load("timestamp", &x.timestamp)
m.Load("tos", &x.tos)
m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
@@ -54,6 +56,7 @@ func (x *endpoint) save(m state.Map) {
m.Save("boundPortFlags", &x.boundPortFlags)
m.Save("sendTOS", &x.sendTOS)
m.Save("receiveTOS", &x.receiveTOS)
+ m.Save("receiveIPPacketInfo", &x.receiveIPPacketInfo)
m.Save("shutdownFlags", &x.shutdownFlags)
m.Save("multicastMemberships", &x.multicastMemberships)
m.Save("effectiveNetProtos", &x.effectiveNetProtos)
@@ -83,6 +86,7 @@ func (x *endpoint) load(m state.Map) {
m.Load("boundPortFlags", &x.boundPortFlags)
m.Load("sendTOS", &x.sendTOS)
m.Load("receiveTOS", &x.receiveTOS)
+ m.Load("receiveIPPacketInfo", &x.receiveIPPacketInfo)
m.Load("shutdownFlags", &x.shutdownFlags)
m.Load("multicastMemberships", &x.multicastMemberships)
m.Load("effectiveNetProtos", &x.effectiveNetProtos)