summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/linux_abi_autogen_unsafe.go12
-rw-r--r--pkg/sentry/socket/netstack/netstack.go24
-rw-r--r--pkg/sentry/socket/netstack/netstack_state_autogen.go19
-rw-r--r--pkg/tcpip/tcpip.go19
-rw-r--r--pkg/tcpip/tcpip_state_autogen.go24
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go18
-rw-r--r--pkg/tcpip/transport/packet/packet_state_autogen.go3
7 files changed, 100 insertions, 19 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go
index 2c16ee9a7..4c8518125 100644
--- a/pkg/abi/linux/linux_abi_autogen_unsafe.go
+++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go
@@ -124,12 +124,12 @@ func (s *Statx) UnmarshalBytes(src []byte) {
// Packed implements marshal.Marshallable.Packed.
//go:nosplit
func (s *Statx) Packed() bool {
- return s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed()
+ return s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed()
}
// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
func (s *Statx) MarshalUnsafe(dst []byte) {
- if s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() {
+ if s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
safecopy.CopyIn(dst, unsafe.Pointer(s))
} else {
s.MarshalBytes(dst)
@@ -138,7 +138,7 @@ func (s *Statx) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (s *Statx) UnmarshalUnsafe(src []byte) {
- if s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
+ if s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
safecopy.CopyOut(unsafe.Pointer(s), src)
} else {
s.UnmarshalBytes(src)
@@ -148,7 +148,7 @@ func (s *Statx) UnmarshalUnsafe(src []byte) {
// CopyOutN implements marshal.Marshallable.CopyOutN.
//go:nosplit
func (s *Statx) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
- if !s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() {
+ if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
s.MarshalBytes(buf) // escapes: fallback.
@@ -178,7 +178,7 @@ func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
+ if !s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
@@ -204,7 +204,7 @@ func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
// WriteTo implements io.WriterTo.WriteTo.
func (s *Statx) WriteTo(w io.Writer) (int64, error) {
- if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
+ if !s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := make([]byte, s.SizeBytes())
s.MarshalBytes(buf)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 0b1be1bd2..49a04e613 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -297,8 +297,9 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
- sender tcpip.FullAddress
+ readCM tcpip.ControlMessages
+ sender tcpip.FullAddress
+ linkPacketInfo tcpip.LinkPacketInfo
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
// of returned messages can be returned via control messages. When
@@ -447,8 +448,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = nil
s.sender = tcpip.FullAddress{}
+ s.linkPacketInfo = tcpip.LinkPacketInfo{}
- v, cms, err := s.Endpoint.Read(&s.sender)
+ var v buffer.View
+ var cms tcpip.ControlMessages
+ var err *tcpip.Error
+
+ switch e := s.Endpoint.(type) {
+ // The ordering of these interfaces matters. The most specific
+ // interfaces must be specified before the more generic Endpoint
+ // interface.
+ case tcpip.PacketEndpoint:
+ v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
+ case tcpip.Endpoint:
+ v, cms, err = e.Read(&s.sender)
+ }
if err != nil {
atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
@@ -2509,6 +2523,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
var addrLen uint32
if isPacket && senderRequested {
addr, addrLen = ConvertAddress(s.family, s.sender)
+ switch v := addr.(type) {
+ case *linux.SockAddrLink:
+ v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ }
}
if peek {
diff --git a/pkg/sentry/socket/netstack/netstack_state_autogen.go b/pkg/sentry/socket/netstack/netstack_state_autogen.go
index b7cf0b290..27e3ada76 100644
--- a/pkg/sentry/socket/netstack/netstack_state_autogen.go
+++ b/pkg/sentry/socket/netstack/netstack_state_autogen.go
@@ -45,6 +45,7 @@ func (x *socketOpsCommon) StateFields() []string {
"readView",
"readCM",
"sender",
+ "linkPacketInfo",
"sockOptTimestamp",
"timestampValid",
"timestampNS",
@@ -66,10 +67,11 @@ func (x *socketOpsCommon) StateSave(m state.Sink) {
m.Save(7, &x.readView)
m.Save(8, &x.readCM)
m.Save(9, &x.sender)
- m.Save(10, &x.sockOptTimestamp)
- m.Save(11, &x.timestampValid)
- m.Save(12, &x.timestampNS)
- m.Save(13, &x.sockOptInq)
+ m.Save(10, &x.linkPacketInfo)
+ m.Save(11, &x.sockOptTimestamp)
+ m.Save(12, &x.timestampValid)
+ m.Save(13, &x.timestampNS)
+ m.Save(14, &x.sockOptInq)
}
func (x *socketOpsCommon) afterLoad() {}
@@ -85,10 +87,11 @@ func (x *socketOpsCommon) StateLoad(m state.Source) {
m.Load(7, &x.readView)
m.Load(8, &x.readCM)
m.Load(9, &x.sender)
- m.Load(10, &x.sockOptTimestamp)
- m.Load(11, &x.timestampValid)
- m.Load(12, &x.timestampNS)
- m.Load(13, &x.sockOptInq)
+ m.Load(10, &x.linkPacketInfo)
+ m.Load(11, &x.sockOptTimestamp)
+ m.Load(12, &x.timestampValid)
+ m.Load(13, &x.timestampNS)
+ m.Load(14, &x.sockOptInq)
}
func (x *Stack) StateTypeName() string {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 71bcee785..48ad56d4d 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -549,6 +549,25 @@ type Endpoint interface {
SetOwner(owner PacketOwner)
}
+// LinkPacketInfo holds Link layer information for a received packet.
+//
+// +stateify savable
+type LinkPacketInfo struct {
+ // Protocol is the NetworkProtocolNumber for the packet.
+ Protocol NetworkProtocolNumber
+}
+
+// PacketEndpoint are additional methods that are only implemented by Packet
+// endpoints.
+type PacketEndpoint interface {
+ // ReadPacket reads a datagram/packet from the endpoint and optionally
+ // returns the sender and additional LinkPacketInfo.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
+}
+
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
index 28d5ae82b..e79c0b491 100644
--- a/pkg/tcpip/tcpip_state_autogen.go
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -85,6 +85,29 @@ func (x *ControlMessages) StateLoad(m state.Source) {
m.Load(9, &x.PacketInfo)
}
+func (x *LinkPacketInfo) StateTypeName() string {
+ return "pkg/tcpip.LinkPacketInfo"
+}
+
+func (x *LinkPacketInfo) StateFields() []string {
+ return []string{
+ "Protocol",
+ }
+}
+
+func (x *LinkPacketInfo) beforeSave() {}
+
+func (x *LinkPacketInfo) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.Protocol)
+}
+
+func (x *LinkPacketInfo) afterLoad() {}
+
+func (x *LinkPacketInfo) StateLoad(m state.Source) {
+ m.Load(0, &x.Protocol)
+}
+
func (x *IPPacketInfo) StateTypeName() string {
return "pkg/tcpip.IPPacketInfo"
}
@@ -117,5 +140,6 @@ func (x *IPPacketInfo) StateLoad(m state.Source) {
func init() {
state.Register((*FullAddress)(nil))
state.Register((*ControlMessages)(nil))
+ state.Register((*LinkPacketInfo)(nil))
state.Register((*IPPacketInfo)(nil))
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 92b487381..7b2083a09 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -45,6 +45,9 @@ type packet struct {
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information like the protocol
+ // of the packet etc.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -151,8 +154,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -177,9 +180,18 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
+
func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// TODO(b/129292371): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -428,12 +440,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
}
if ep.cooked {
diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go
index b7fa1cdc9..e13dd7827 100644
--- a/pkg/tcpip/transport/packet/packet_state_autogen.go
+++ b/pkg/tcpip/transport/packet/packet_state_autogen.go
@@ -17,6 +17,7 @@ func (x *packet) StateFields() []string {
"data",
"timestampNS",
"senderAddr",
+ "packetInfo",
}
}
@@ -29,6 +30,7 @@ func (x *packet) StateSave(m state.Sink) {
m.Save(0, &x.packetEntry)
m.Save(2, &x.timestampNS)
m.Save(3, &x.senderAddr)
+ m.Save(4, &x.packetInfo)
}
func (x *packet) afterLoad() {}
@@ -37,6 +39,7 @@ func (x *packet) StateLoad(m state.Source) {
m.Load(0, &x.packetEntry)
m.Load(2, &x.timestampNS)
m.Load(3, &x.senderAddr)
+ m.Load(4, &x.packetInfo)
m.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
}