diff options
-rw-r--r-- | pkg/abi/linux/linux_abi_autogen_unsafe.go | 12 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 24 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_state_autogen.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/tcpip_state_autogen.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/packet_state_autogen.go | 3 |
7 files changed, 100 insertions, 19 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go index 2c16ee9a7..4c8518125 100644 --- a/pkg/abi/linux/linux_abi_autogen_unsafe.go +++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go @@ -124,12 +124,12 @@ func (s *Statx) UnmarshalBytes(src []byte) { // Packed implements marshal.Marshallable.Packed. //go:nosplit func (s *Statx) Packed() bool { - return s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() + return s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() } // MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. func (s *Statx) MarshalUnsafe(dst []byte) { - if s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() { + if s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() { safecopy.CopyIn(dst, unsafe.Pointer(s)) } else { s.MarshalBytes(dst) @@ -138,7 +138,7 @@ func (s *Statx) MarshalUnsafe(dst []byte) { // UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. func (s *Statx) UnmarshalUnsafe(src []byte) { - if s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() { + if s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { safecopy.CopyOut(unsafe.Pointer(s), src) } else { s.UnmarshalBytes(src) @@ -148,7 +148,7 @@ func (s *Statx) UnmarshalUnsafe(src []byte) { // CopyOutN implements marshal.Marshallable.CopyOutN. //go:nosplit func (s *Statx) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { - if !s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() { + if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { // Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes. buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. s.MarshalBytes(buf) // escapes: fallback. @@ -178,7 +178,7 @@ func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { // CopyIn implements marshal.Marshallable.CopyIn. //go:nosplit func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + if !s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() { // Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes. buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. length, err := task.CopyInBytes(addr, buf) // escapes: okay. @@ -204,7 +204,7 @@ func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { // WriteTo implements io.WriterTo.WriteTo. func (s *Statx) WriteTo(w io.Writer) (int64, error) { - if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { + if !s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() { // Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes. buf := make([]byte, s.SizeBytes()) s.MarshalBytes(buf) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 0b1be1bd2..49a04e613 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -297,8 +297,9 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -447,8 +448,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) @@ -2509,6 +2523,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + } } if peek { diff --git a/pkg/sentry/socket/netstack/netstack_state_autogen.go b/pkg/sentry/socket/netstack/netstack_state_autogen.go index b7cf0b290..27e3ada76 100644 --- a/pkg/sentry/socket/netstack/netstack_state_autogen.go +++ b/pkg/sentry/socket/netstack/netstack_state_autogen.go @@ -45,6 +45,7 @@ func (x *socketOpsCommon) StateFields() []string { "readView", "readCM", "sender", + "linkPacketInfo", "sockOptTimestamp", "timestampValid", "timestampNS", @@ -66,10 +67,11 @@ func (x *socketOpsCommon) StateSave(m state.Sink) { m.Save(7, &x.readView) m.Save(8, &x.readCM) m.Save(9, &x.sender) - m.Save(10, &x.sockOptTimestamp) - m.Save(11, &x.timestampValid) - m.Save(12, &x.timestampNS) - m.Save(13, &x.sockOptInq) + m.Save(10, &x.linkPacketInfo) + m.Save(11, &x.sockOptTimestamp) + m.Save(12, &x.timestampValid) + m.Save(13, &x.timestampNS) + m.Save(14, &x.sockOptInq) } func (x *socketOpsCommon) afterLoad() {} @@ -85,10 +87,11 @@ func (x *socketOpsCommon) StateLoad(m state.Source) { m.Load(7, &x.readView) m.Load(8, &x.readCM) m.Load(9, &x.sender) - m.Load(10, &x.sockOptTimestamp) - m.Load(11, &x.timestampValid) - m.Load(12, &x.timestampNS) - m.Load(13, &x.sockOptInq) + m.Load(10, &x.linkPacketInfo) + m.Load(11, &x.sockOptTimestamp) + m.Load(12, &x.timestampValid) + m.Load(13, &x.timestampNS) + m.Load(14, &x.sockOptInq) } func (x *Stack) StateTypeName() string { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 71bcee785..48ad56d4d 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -549,6 +549,25 @@ type Endpoint interface { SetOwner(owner PacketOwner) } +// LinkPacketInfo holds Link layer information for a received packet. +// +// +stateify savable +type LinkPacketInfo struct { + // Protocol is the NetworkProtocolNumber for the packet. + Protocol NetworkProtocolNumber +} + +// PacketEndpoint are additional methods that are only implemented by Packet +// endpoints. +type PacketEndpoint interface { + // ReadPacket reads a datagram/packet from the endpoint and optionally + // returns the sender and additional LinkPacketInfo. + // + // This method does not block if there is no data pending. It will also + // either return an error or data, never both. + ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) +} + // EndpointInfo is the interface implemented by each endpoint info struct. type EndpointInfo interface { // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go index 28d5ae82b..e79c0b491 100644 --- a/pkg/tcpip/tcpip_state_autogen.go +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -85,6 +85,29 @@ func (x *ControlMessages) StateLoad(m state.Source) { m.Load(9, &x.PacketInfo) } +func (x *LinkPacketInfo) StateTypeName() string { + return "pkg/tcpip.LinkPacketInfo" +} + +func (x *LinkPacketInfo) StateFields() []string { + return []string{ + "Protocol", + } +} + +func (x *LinkPacketInfo) beforeSave() {} + +func (x *LinkPacketInfo) StateSave(m state.Sink) { + x.beforeSave() + m.Save(0, &x.Protocol) +} + +func (x *LinkPacketInfo) afterLoad() {} + +func (x *LinkPacketInfo) StateLoad(m state.Source) { + m.Load(0, &x.Protocol) +} + func (x *IPPacketInfo) StateTypeName() string { return "pkg/tcpip.IPPacketInfo" } @@ -117,5 +140,6 @@ func (x *IPPacketInfo) StateLoad(m state.Source) { func init() { state.Register((*FullAddress)(nil)) state.Register((*ControlMessages)(nil)) + state.Register((*LinkPacketInfo)(nil)) state.Register((*IPPacketInfo)(nil)) } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 92b487381..7b2083a09 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -45,6 +45,9 @@ type packet struct { timestampNS int64 // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + // packetInfo holds additional information like the protocol + // of the packet etc. + packetInfo tcpip.LinkPacketInfo } // endpoint is the packet socket implementation of tcpip.Endpoint. It is legal @@ -151,8 +154,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.PacketEndpoint.ReadPacket. +func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -177,9 +180,18 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes *addr = packet.senderAddr } + if info != nil { + *info = packet.packetInfo + } + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil } +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return ep.ReadPacket(addr, nil) +} + func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(b/129292371): Implement. return 0, nil, tcpip.ErrInvalidOptionValue @@ -428,12 +440,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } + packet.packetInfo.Protocol = netProto } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(localAddr), } + packet.packetInfo.Protocol = netProto } if ep.cooked { diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go index b7fa1cdc9..e13dd7827 100644 --- a/pkg/tcpip/transport/packet/packet_state_autogen.go +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -17,6 +17,7 @@ func (x *packet) StateFields() []string { "data", "timestampNS", "senderAddr", + "packetInfo", } } @@ -29,6 +30,7 @@ func (x *packet) StateSave(m state.Sink) { m.Save(0, &x.packetEntry) m.Save(2, &x.timestampNS) m.Save(3, &x.senderAddr) + m.Save(4, &x.packetInfo) } func (x *packet) afterLoad() {} @@ -37,6 +39,7 @@ func (x *packet) StateLoad(m state.Source) { m.Load(0, &x.packetEntry) m.Load(2, &x.timestampNS) m.Load(3, &x.senderAddr) + m.Load(4, &x.packetInfo) m.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) } |