summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/control/control.go53
-rw-r--r--pkg/sentry/socket/hostinet/socket.go4
-rw-r--r--pkg/sentry/socket/netstack/netstack.go27
-rw-r--r--pkg/sentry/socket/socket.go74
-rw-r--r--pkg/sentry/socket/socket_state_autogen.go54
5 files changed, 167 insertions, 45 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index c284efde5..5b81e8379 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -344,32 +343,32 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
}
// PackIPPacketInfo packs an IP_PKTINFO socket control message.
-func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
- var p linux.ControlMessageIPPacketInfo
- p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
-
+func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
linux.IP_PKTINFO,
t.Arch().Width(),
- p,
+ packetInfo,
)
}
// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
-func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip.FullAddress, buf []byte) []byte {
- p, _ := socket.ConvertAddress(family, originalDstAddress)
- level := uint32(linux.SOL_IP)
- optType := uint32(linux.IP_RECVORIGDSTADDR)
- if family == linux.AF_INET6 {
+func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
+ var level uint32
+ var optType uint32
+ switch originalDstAddress.(type) {
+ case *linux.SockAddrInet:
+ level = linux.SOL_IP
+ optType = linux.IP_RECVORIGDSTADDR
+ case *linux.SockAddrInet6:
level = linux.SOL_IPV6
optType = linux.IPV6_RECVORIGDSTADDR
+ default:
+ panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
}
return putCmsgStruct(
- buf, level, optType, t.Arch().Width(), p)
+ buf, level, optType, t.Arch().Width(), originalDstAddress)
}
// PackControlMessages packs control messages into the given buffer.
@@ -378,7 +377,7 @@ func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip
//
// Note that some control messages may be truncated if they do not fit under
// the capacity of buf.
-func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessages, buf []byte) []byte {
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
if cmsgs.IP.HasTimestamp {
buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
}
@@ -397,11 +396,11 @@ func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessage
}
if cmsgs.IP.HasIPPacketInfo {
- buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
}
- if cmsgs.IP.HasOriginalDstAddress {
- buf = PackOriginalDstAddress(t, family, cmsgs.IP.OriginalDstAddress, buf)
+ if cmsgs.IP.OriginalDstAddress != nil {
+ buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
return buf
@@ -433,17 +432,15 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
}
- return space
-}
+ if cmsgs.IP.HasIPPacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
+ }
-// NewIPPacketInfo returns the IPPacketInfo struct.
-func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
- var p tcpip.IPPacketInfo
- p.NIC = tcpip.NICID(packetInfo.NIC)
- copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
- copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+ if cmsgs.IP.OriginalDstAddress != nil {
+ space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
+ }
- return p
+ return space
}
// Parse parses a raw socket control message into portable objects.
@@ -529,7 +526,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ cmsgs.IP.PacketInfo = packetInfo
i += binary.AlignUp(length, width)
default:
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index a35850a8f..9b337d286 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -523,7 +523,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ controlMessages.IP.PacketInfo = packetInfo
}
case syscall.SOL_IPV6:
@@ -551,7 +551,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
}
controlBuf := make([]byte, 0, space)
// PackControlMessages will append up to space bytes to controlBuf.
- controlBuf = control.PackControlMessages(t, s.family, controlMessages, controlBuf)
+ controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of src.Addrs was unusable.
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 37c3faa57..af75e2670 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -320,7 +320,7 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
+ readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
@@ -408,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = v
- s.readCM = cms
+ s.readCM = socket.NewIPControlMessages(s.family, cms)
atomic.StoreUint32(&s.readViewHasData, 1)
return nil
@@ -2736,7 +2736,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// We need to peek beyond the first message.
dst = dst.DropFirst(n)
num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, _, err := s.Endpoint.Peek(dsts)
+ n, err := s.Endpoint.Peek(dsts)
// TODO(b/78348848): Handle peek timestamp.
if err != nil {
return int64(n), syserr.TranslateNetstackError(err).ToError()
@@ -2780,17 +2780,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
- IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
- HasOriginalDstAddress: s.readCM.HasOriginalDstAddress,
- OriginalDstAddress: s.readCM.OriginalDstAddress,
+ IP: socket.IPControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
},
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 9049e8a21..bcc426e33 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -44,7 +44,79 @@ import (
// control messages.
type ControlMessages struct {
Unix transport.ControlMessages
- IP tcpip.ControlMessages
+ IP IPControlMessages
+}
+
+// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format.
+func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ return p
+}
+
+// NewIPControlMessages converts the tcpip ControlMessgaes (which does not
+// have Linux specific format) to Linux format.
+func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages {
+ var orgDstAddr linux.SockAddr
+ if cmgs.HasOriginalDstAddress {
+ orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
+ }
+ return IPControlMessages{
+ HasTimestamp: cmgs.HasTimestamp,
+ Timestamp: cmgs.Timestamp,
+ HasInq: cmgs.HasInq,
+ Inq: cmgs.Inq,
+ HasTOS: cmgs.HasTOS,
+ TOS: cmgs.TOS,
+ HasTClass: cmgs.HasTClass,
+ TClass: cmgs.TClass,
+ HasIPPacketInfo: cmgs.HasIPPacketInfo,
+ PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ OriginalDstAddress: orgDstAddr,
+ }
+}
+
+// IPControlMessages contains socket control messages for IP sockets.
+// This can contain Linux specific structures unlike tcpip.ControlMessages.
+//
+// +stateify savable
+type IPControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
+
+ // HasTOS indicates whether Tos is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo linux.ControlMessageIPPacketInfo
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress linux.SockAddr
}
// Release releases Unix domain socket credentials and rights.
diff --git a/pkg/sentry/socket/socket_state_autogen.go b/pkg/sentry/socket/socket_state_autogen.go
index b5277ffab..4f854f99f 100644
--- a/pkg/sentry/socket/socket_state_autogen.go
+++ b/pkg/sentry/socket/socket_state_autogen.go
@@ -6,6 +6,59 @@ import (
"gvisor.dev/gvisor/pkg/state"
)
+func (i *IPControlMessages) StateTypeName() string {
+ return "pkg/sentry/socket.IPControlMessages"
+}
+
+func (i *IPControlMessages) StateFields() []string {
+ return []string{
+ "HasTimestamp",
+ "Timestamp",
+ "HasInq",
+ "Inq",
+ "HasTOS",
+ "TOS",
+ "HasTClass",
+ "TClass",
+ "HasIPPacketInfo",
+ "PacketInfo",
+ "OriginalDstAddress",
+ }
+}
+
+func (i *IPControlMessages) beforeSave() {}
+
+func (i *IPControlMessages) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.HasTimestamp)
+ stateSinkObject.Save(1, &i.Timestamp)
+ stateSinkObject.Save(2, &i.HasInq)
+ stateSinkObject.Save(3, &i.Inq)
+ stateSinkObject.Save(4, &i.HasTOS)
+ stateSinkObject.Save(5, &i.TOS)
+ stateSinkObject.Save(6, &i.HasTClass)
+ stateSinkObject.Save(7, &i.TClass)
+ stateSinkObject.Save(8, &i.HasIPPacketInfo)
+ stateSinkObject.Save(9, &i.PacketInfo)
+ stateSinkObject.Save(10, &i.OriginalDstAddress)
+}
+
+func (i *IPControlMessages) afterLoad() {}
+
+func (i *IPControlMessages) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.HasTimestamp)
+ stateSourceObject.Load(1, &i.Timestamp)
+ stateSourceObject.Load(2, &i.HasInq)
+ stateSourceObject.Load(3, &i.Inq)
+ stateSourceObject.Load(4, &i.HasTOS)
+ stateSourceObject.Load(5, &i.TOS)
+ stateSourceObject.Load(6, &i.HasTClass)
+ stateSourceObject.Load(7, &i.TClass)
+ stateSourceObject.Load(8, &i.HasIPPacketInfo)
+ stateSourceObject.Load(9, &i.PacketInfo)
+ stateSourceObject.Load(10, &i.OriginalDstAddress)
+}
+
func (to *SendReceiveTimeout) StateTypeName() string {
return "pkg/sentry/socket.SendReceiveTimeout"
}
@@ -33,5 +86,6 @@ func (to *SendReceiveTimeout) StateLoad(stateSourceObject state.Source) {
}
func init() {
+ state.Register((*IPControlMessages)(nil))
state.Register((*SendReceiveTimeout)(nil))
}