diff options
Diffstat (limited to 'pkg/sentry/socket/netstack')
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 24 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_state_autogen.go | 19 |
2 files changed, 32 insertions, 11 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 0b1be1bd2..49a04e613 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -297,8 +297,9 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -447,8 +448,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) @@ -2509,6 +2523,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + } } if peek { diff --git a/pkg/sentry/socket/netstack/netstack_state_autogen.go b/pkg/sentry/socket/netstack/netstack_state_autogen.go index b7cf0b290..27e3ada76 100644 --- a/pkg/sentry/socket/netstack/netstack_state_autogen.go +++ b/pkg/sentry/socket/netstack/netstack_state_autogen.go @@ -45,6 +45,7 @@ func (x *socketOpsCommon) StateFields() []string { "readView", "readCM", "sender", + "linkPacketInfo", "sockOptTimestamp", "timestampValid", "timestampNS", @@ -66,10 +67,11 @@ func (x *socketOpsCommon) StateSave(m state.Sink) { m.Save(7, &x.readView) m.Save(8, &x.readCM) m.Save(9, &x.sender) - m.Save(10, &x.sockOptTimestamp) - m.Save(11, &x.timestampValid) - m.Save(12, &x.timestampNS) - m.Save(13, &x.sockOptInq) + m.Save(10, &x.linkPacketInfo) + m.Save(11, &x.sockOptTimestamp) + m.Save(12, &x.timestampValid) + m.Save(13, &x.timestampNS) + m.Save(14, &x.sockOptInq) } func (x *socketOpsCommon) afterLoad() {} @@ -85,10 +87,11 @@ func (x *socketOpsCommon) StateLoad(m state.Source) { m.Load(7, &x.readView) m.Load(8, &x.readCM) m.Load(9, &x.sender) - m.Load(10, &x.sockOptTimestamp) - m.Load(11, &x.timestampValid) - m.Load(12, &x.timestampNS) - m.Load(13, &x.sockOptInq) + m.Load(10, &x.linkPacketInfo) + m.Load(11, &x.sockOptTimestamp) + m.Load(12, &x.timestampValid) + m.Load(13, &x.timestampNS) + m.Load(14, &x.sockOptInq) } func (x *Stack) StateTypeName() string { |