diff options
author | gVisor bot <gvisor-bot@google.com> | 2020-12-10 00:26:19 +0000 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-12-10 00:26:19 +0000 |
commit | 23f464469bb7ccd79928a7a749219336b7bcb2c0 (patch) | |
tree | e2e9e61f29e18cadb71aa78302865d807733f95a /pkg | |
parent | 1faf5799ed1fb83a36c6433b22a7c8f495395943 (diff) | |
parent | 92ca72ecb73d91e9def31e7f9835adf7a50b3d65 (diff) |
Merge release-20201130.0-74-g92ca72ecb (automated)
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/control/control.go | 19 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 57 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_socket.go | 3 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/vfs2/socket.go | 3 | ||||
-rw-r--r-- | pkg/sentry/vfs/mount_unsafe.go | 36 | ||||
-rw-r--r-- | pkg/sentry/vfs/save_restore.go | 6 | ||||
-rw-r--r-- | pkg/sentry/vfs/vfs_unsafe_state_autogen.go | 37 | ||||
-rw-r--r-- | pkg/sync/runtime_unsafe.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/socketops.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tcpip_state_autogen.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_state_autogen.go | 19 |
14 files changed, 198 insertions, 86 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..c284efde5 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -359,13 +359,26 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) ) } +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip.FullAddress, buf []byte) []byte { + p, _ := socket.ConvertAddress(family, originalDstAddress) + level := uint32(linux.SOL_IP) + optType := uint32(linux.IP_RECVORIGDSTADDR) + if family == linux.AF_INET6 { + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), p) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. // // Note that some control messages may be truncated if they do not fit under // the capacity of buf. -func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { +func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessages, buf []byte) []byte { if cmsgs.IP.HasTimestamp { buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) } @@ -387,6 +400,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) } + if cmsgs.IP.HasOriginalDstAddress { + buf = PackOriginalDstAddress(t, family, cmsgs.IP.OriginalDstAddress, buf) + } + return buf } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..a35850a8f 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -551,7 +551,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } controlBuf := make([]byte, 0, space) // PackControlMessages will append up to space bytes to controlBuf. - controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + controlBuf = control.PackControlMessages(t, s.family, controlMessages, controlBuf) sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 3cc0d4f0f..37c3faa57 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1418,6 +1418,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) return &v, nil + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil + case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { return nil, syserr.ErrInvalidArgument @@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) return &v, nil + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil + case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { return nil, syserr.ErrInvalidArgument @@ -2094,6 +2110,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2325,6 +2350,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetHeaderIncluded(v != 0) return nil + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { return syserr.ErrInvalidArgument @@ -2363,7 +2400,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2441,7 +2477,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2746,14 +2781,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + HasOriginalDstAddress: s.readCM.HasOriginalDstAddress, + OriginalDstAddress: s.readCM.OriginalDstAddress, }, } } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9cd052c3d..754a5df24 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -784,8 +784,9 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } defer cms.Release(t) + family, _, _ := s.Type() controlData := make([]byte, 0, msg.ControlLen) - controlData = control.PackControlMessages(t, cms, controlData) + controlData = control.PackControlMessages(t, family, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 7b33b3f59..c40892b9f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -787,8 +787,9 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } defer cms.Release(t) + family, _, _ := s.Type() controlData := make([]byte, 0, msg.ControlLen) - controlData = control.PackControlMessages(t, cms, controlData) + controlData = control.PackControlMessages(t, family, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index cb48c37a1..0df023713 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.12 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - package vfs import ( @@ -41,6 +36,15 @@ type mountKey struct { point unsafe.Pointer // *Dentry } +var ( + mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil)) + mountKeySeed = sync.RandUintptr() +) + +func (k *mountKey) hash() uintptr { + return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed) +} + func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry { } } -func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } - // Invariant: mnt.key.parent == nil. vd.Ok(). func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } -func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } - // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). // // mountTable.Init() must be called on new mountTables before use. -// -// +stateify savable type mountTable struct { // mountTable is implemented as a seqcount-protected hash table that // resolves collisions with linear probing, featuring Robin Hood insertion @@ -84,8 +82,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq sync.SeqCount `state:"nosave"` - seed uint32 // for hashing keys + seq sync.SeqCount `state:"nosave"` // size holds both length (number of elements) and capacity (number of // slots): capacity is stored as its base-2 log (referred to as order) in @@ -150,7 +147,6 @@ func init() { // Init must be called exactly once on each mountTable before use. func (mt *mountTable) Init() { - mt.seed = rand32() mt.size = mtInitOrder mt.slots = newMountTableSlots(mtInitCap) } @@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := key.hash() loop: for { @@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must not already contain a Mount with the same mount point and parent. func (mt *mountTable) insertSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() // We're under the maximum load factor if: // @@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must contain mount. func (mt *mountTable) removeSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 slots := mt.slots @@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) { off = (off + mountSlotBytes) & offmask } } - -//go:linkname memhash runtime.memhash -func memhash(p unsafe.Pointer, seed, s uintptr) uintptr - -//go:linkname rand32 runtime.fastrand -func rand32() uint32 diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go index 8f070ed53..8998a82dd 100644 --- a/pkg/sentry/vfs/save_restore.go +++ b/pkg/sentry/vfs/save_restore.go @@ -101,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount { return mounts } +// saveKey is called by stateify. +func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } + // loadMounts is called by stateify. func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { if mounts == nil { @@ -112,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { } } +// loadKey is called by stateify. +func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } + func (mnt *Mount) afterLoad() { if atomic.LoadInt64(&mnt.refs) != 0 { refsvfs2.Register(mnt) diff --git a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go index d23733ee7..20f06c953 100644 --- a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go +++ b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go @@ -1,40 +1,3 @@ // automatically generated by stateify. -// +build go1.12 -// +build !go1.17 - package vfs - -import ( - "gvisor.dev/gvisor/pkg/state" -) - -func (mt *mountTable) StateTypeName() string { - return "pkg/sentry/vfs.mountTable" -} - -func (mt *mountTable) StateFields() []string { - return []string{ - "seed", - "size", - } -} - -func (mt *mountTable) beforeSave() {} - -func (mt *mountTable) StateSave(stateSinkObject state.Sink) { - mt.beforeSave() - stateSinkObject.Save(0, &mt.seed) - stateSinkObject.Save(1, &mt.size) -} - -func (mt *mountTable) afterLoad() {} - -func (mt *mountTable) StateLoad(stateSourceObject state.Source) { - stateSourceObject.Load(0, &mt.seed) - stateSourceObject.Load(1, &mt.size) -} - -func init() { - state.Register((*mountTable)(nil)) -} diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go index 7ad6a4434..e925e2e5b 100644 --- a/pkg/sync/runtime_unsafe.go +++ b/pkg/sync/runtime_unsafe.go @@ -11,6 +11,8 @@ package sync import ( + "fmt" + "reflect" "unsafe" ) @@ -61,6 +63,57 @@ const ( TraceEvGoBlockSelect byte = 24 ) +// Rand32 returns a non-cryptographically-secure random uint32. +func Rand32() uint32 { + return fastrand() +} + +// Rand64 returns a non-cryptographically-secure random uint64. +func Rand64() uint64 { + return uint64(fastrand())<<32 | uint64(fastrand()) +} + +//go:linkname fastrand runtime.fastrand +func fastrand() uint32 + +// RandUintptr returns a non-cryptographically-secure random uintptr. +func RandUintptr() uintptr { + if unsafe.Sizeof(uintptr(0)) == 4 { + return uintptr(Rand32()) + } + return uintptr(Rand64()) +} + +// MapKeyHasher returns a hash function for pointers of m's key type. +// +// Preconditions: m must be a map. +func MapKeyHasher(m interface{}) func(unsafe.Pointer, uintptr) uintptr { + if rtyp := reflect.TypeOf(m); rtyp.Kind() != reflect.Map { + panic(fmt.Sprintf("sync.MapKeyHasher: m is %v, not map", rtyp)) + } + mtyp := *(**maptype)(unsafe.Pointer(&m)) + return mtyp.hasher +} + +type maptype struct { + size uintptr + ptrdata uintptr + hash uint32 + tflag uint8 + align uint8 + fieldAlign uint8 + kind uint8 + equal func(unsafe.Pointer, unsafe.Pointer) bool + gcdata *byte + str int32 + ptrToThis int32 + key unsafe.Pointer + elem unsafe.Pointer + bucket unsafe.Pointer + hasher func(unsafe.Pointer, uintptr) uintptr + // more fields +} + // These functions are only used within the sync package. //go:linkname semacquire sync.runtime_Semacquire diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index c53698a6a..b14df9e09 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -130,6 +130,10 @@ type SocketOptions struct { // corkOptionEnabled is used to specify if data should be held until segments // are full by the TCP transport protocol. corkOptionEnabled uint32 + + // receiveOriginalDstAddress is used to specify if the original destination of + // the incoming packet should be returned as an ancillary message. + receiveOriginalDstAddress uint32 } // InitHandler initializes the handler. This must be called before using the @@ -302,3 +306,13 @@ func (so *SocketOptions) SetCorkOption(v bool) { storeAtomicBool(&so.corkOptionEnabled, v) so.handler.OnCorkOptionSet(v) } + +// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) GetReceiveOriginalDstAddress() bool { + return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0 +} + +// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { + storeAtomicBool(&so.receiveOriginalDstAddress, v) +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 23e597365..0d2dad881 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -492,6 +492,14 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is + // set. + HasOriginalDstAddress bool + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress FullAddress } // PacketOwner is used to get UID and GID of the packet. diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go index 94e40e30e..36ef19575 100644 --- a/pkg/tcpip/tcpip_state_autogen.go +++ b/pkg/tcpip/tcpip_state_autogen.go @@ -28,6 +28,7 @@ func (so *SocketOptions) StateFields() []string { "quickAckEnabled", "delayOptionEnabled", "corkOptionEnabled", + "receiveOriginalDstAddress", } } @@ -51,6 +52,7 @@ func (so *SocketOptions) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(13, &so.quickAckEnabled) stateSinkObject.Save(14, &so.delayOptionEnabled) stateSinkObject.Save(15, &so.corkOptionEnabled) + stateSinkObject.Save(16, &so.receiveOriginalDstAddress) } func (so *SocketOptions) afterLoad() {} @@ -72,6 +74,7 @@ func (so *SocketOptions) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(13, &so.quickAckEnabled) stateSourceObject.Load(14, &so.delayOptionEnabled) stateSourceObject.Load(15, &so.corkOptionEnabled) + stateSourceObject.Load(16, &so.receiveOriginalDstAddress) } func (e *Error) StateTypeName() string { @@ -145,6 +148,8 @@ func (c *ControlMessages) StateFields() []string { "TClass", "HasIPPacketInfo", "PacketInfo", + "HasOriginalDstAddress", + "OriginalDstAddress", } } @@ -162,6 +167,8 @@ func (c *ControlMessages) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(7, &c.TClass) stateSinkObject.Save(8, &c.HasIPPacketInfo) stateSinkObject.Save(9, &c.PacketInfo) + stateSinkObject.Save(10, &c.HasOriginalDstAddress) + stateSinkObject.Save(11, &c.OriginalDstAddress) } func (c *ControlMessages) afterLoad() {} @@ -177,6 +184,8 @@ func (c *ControlMessages) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(7, &c.TClass) stateSourceObject.Load(8, &c.HasIPPacketInfo) stateSourceObject.Load(9, &c.PacketInfo) + stateSourceObject.Load(10, &c.HasOriginalDstAddress) + stateSourceObject.Load(11, &c.OriginalDstAddress) } func (l *LinkPacketInfo) StateTypeName() string { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index ee1bb29f8..04596183e 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -30,10 +30,11 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry - senderAddress tcpip.FullAddress - packetInfo tcpip.IPPacketInfo - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + senderAddress tcpip.FullAddress + destinationAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 // tos stores either the receiveTOS or receiveTClass value. tos uint8 } @@ -323,6 +324,10 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } + if e.ops.GetReceiveOriginalDstAddress() { + cm.HasOriginalDstAddress = true + cm.OriginalDstAddress = p.destinationAddress + } return p.data.ToView(), cm, nil } @@ -1314,6 +1319,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB Addr: id.RemoteAddress, Port: hdr.SourcePort(), }, + destinationAddress: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: header.UDP(hdr).DestinationPort(), + }, } packet.data = pkt.Data e.rcvList.PushBack(packet) diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go index dc6afdff9..451d8eff0 100644 --- a/pkg/tcpip/transport/udp/udp_state_autogen.go +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -15,6 +15,7 @@ func (u *udpPacket) StateFields() []string { return []string{ "udpPacketEntry", "senderAddress", + "destinationAddress", "packetInfo", "data", "timestamp", @@ -27,12 +28,13 @@ func (u *udpPacket) beforeSave() {} func (u *udpPacket) StateSave(stateSinkObject state.Sink) { u.beforeSave() var dataValue buffer.VectorisedView = u.saveData() - stateSinkObject.SaveValue(3, dataValue) + stateSinkObject.SaveValue(4, dataValue) stateSinkObject.Save(0, &u.udpPacketEntry) stateSinkObject.Save(1, &u.senderAddress) - stateSinkObject.Save(2, &u.packetInfo) - stateSinkObject.Save(4, &u.timestamp) - stateSinkObject.Save(5, &u.tos) + stateSinkObject.Save(2, &u.destinationAddress) + stateSinkObject.Save(3, &u.packetInfo) + stateSinkObject.Save(5, &u.timestamp) + stateSinkObject.Save(6, &u.tos) } func (u *udpPacket) afterLoad() {} @@ -40,10 +42,11 @@ func (u *udpPacket) afterLoad() {} func (u *udpPacket) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(0, &u.udpPacketEntry) stateSourceObject.Load(1, &u.senderAddress) - stateSourceObject.Load(2, &u.packetInfo) - stateSourceObject.Load(4, &u.timestamp) - stateSourceObject.Load(5, &u.tos) - stateSourceObject.LoadValue(3, new(buffer.VectorisedView), func(y interface{}) { u.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.Load(2, &u.destinationAddress) + stateSourceObject.Load(3, &u.packetInfo) + stateSourceObject.Load(5, &u.timestamp) + stateSourceObject.Load(6, &u.tos) + stateSourceObject.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { u.loadData(y.(buffer.VectorisedView)) }) } func (e *endpoint) StateTypeName() string { |