summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/control/control.go19
-rw-r--r--pkg/sentry/socket/hostinet/socket.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go57
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go3
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go36
-rw-r--r--pkg/sentry/vfs/save_restore.go6
-rw-r--r--pkg/sentry/vfs/vfs_unsafe_state_autogen.go37
-rw-r--r--pkg/sync/runtime_unsafe.go53
-rw-r--r--pkg/tcpip/socketops.go14
-rw-r--r--pkg/tcpip/tcpip.go8
-rw-r--r--pkg/tcpip/tcpip_state_autogen.go9
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go18
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go19
14 files changed, 198 insertions, 86 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 70ccf77a7..c284efde5 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -359,13 +359,26 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte)
)
}
+// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
+func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip.FullAddress, buf []byte) []byte {
+ p, _ := socket.ConvertAddress(family, originalDstAddress)
+ level := uint32(linux.SOL_IP)
+ optType := uint32(linux.IP_RECVORIGDSTADDR)
+ if family == linux.AF_INET6 {
+ level = linux.SOL_IPV6
+ optType = linux.IPV6_RECVORIGDSTADDR
+ }
+ return putCmsgStruct(
+ buf, level, optType, t.Arch().Width(), p)
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
//
// Note that some control messages may be truncated if they do not fit under
// the capacity of buf.
-func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
+func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessages, buf []byte) []byte {
if cmsgs.IP.HasTimestamp {
buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
}
@@ -387,6 +400,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
}
+ if cmsgs.IP.HasOriginalDstAddress {
+ buf = PackOriginalDstAddress(t, family, cmsgs.IP.OriginalDstAddress, buf)
+ }
+
return buf
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7d3c4a01c..a35850a8f 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -551,7 +551,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
}
controlBuf := make([]byte, 0, space)
// PackControlMessages will append up to space bytes to controlBuf.
- controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+ controlBuf = control.PackControlMessages(t, s.family, controlMessages, controlBuf)
sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of src.Addrs was unusable.
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 3cc0d4f0f..37c3faa57 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1418,6 +1418,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
return &v, nil
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
+
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
return nil, syserr.ErrInvalidArgument
@@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
return &v, nil
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
+
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
return nil, syserr.ErrInvalidArgument
@@ -2094,6 +2110,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2325,6 +2350,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
ep.SocketOptions().SetHeaderIncluded(v != 0)
return nil
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
return syserr.ErrInvalidArgument
@@ -2363,7 +2400,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2441,7 +2477,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2746,14 +2781,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ HasOriginalDstAddress: s.readCM.HasOriginalDstAddress,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
},
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 9cd052c3d..754a5df24 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -784,8 +784,9 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
defer cms.Release(t)
+ family, _, _ := s.Type()
controlData := make([]byte, 0, msg.ControlLen)
- controlData = control.PackControlMessages(t, cms, controlData)
+ controlData = control.PackControlMessages(t, family, cms, controlData)
if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 7b33b3f59..c40892b9f 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -787,8 +787,9 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
}
defer cms.Release(t)
+ family, _, _ := s.Type()
controlData := make([]byte, 0, msg.ControlLen)
- controlData = control.PackControlMessages(t, cms, controlData)
+ controlData = control.PackControlMessages(t, family, cms, controlData)
if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index cb48c37a1..0df023713 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package vfs
import (
@@ -41,6 +36,15 @@ type mountKey struct {
point unsafe.Pointer // *Dentry
}
+var (
+ mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil))
+ mountKeySeed = sync.RandUintptr()
+)
+
+func (k *mountKey) hash() uintptr {
+ return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed)
+}
+
func (mnt *Mount) parent() *Mount {
return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}
@@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry {
}
}
-func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
-
// Invariant: mnt.key.parent == nil. vd.Ok().
func (mnt *Mount) setKey(vd VirtualDentry) {
atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}
-func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
-
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
-//
-// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -84,8 +82,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq sync.SeqCount `state:"nosave"`
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -150,7 +147,6 @@ func init() {
// Init must be called exactly once on each mountTable before use.
func (mt *mountTable) Init() {
- mt.seed = rand32()
mt.size = mtInitOrder
mt.slots = newMountTableSlots(mtInitCap)
}
@@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := key.hash()
loop:
for {
@@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must not already contain a Mount with the same mount point and parent.
func (mt *mountTable) insertSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
// We're under the maximum load factor if:
//
@@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must contain mount.
func (mt *mountTable) removeSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
mask := tcap - 1
slots := mt.slots
@@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) {
off = (off + mountSlotBytes) & offmask
}
}
-
-//go:linkname memhash runtime.memhash
-func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
-
-//go:linkname rand32 runtime.fastrand
-func rand32() uint32
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
index 8f070ed53..8998a82dd 100644
--- a/pkg/sentry/vfs/save_restore.go
+++ b/pkg/sentry/vfs/save_restore.go
@@ -101,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount {
return mounts
}
+// saveKey is called by stateify.
+func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
+
// loadMounts is called by stateify.
func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
if mounts == nil {
@@ -112,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
}
}
+// loadKey is called by stateify.
+func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
+
func (mnt *Mount) afterLoad() {
if atomic.LoadInt64(&mnt.refs) != 0 {
refsvfs2.Register(mnt)
diff --git a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go
index d23733ee7..20f06c953 100644
--- a/pkg/sentry/vfs/vfs_unsafe_state_autogen.go
+++ b/pkg/sentry/vfs/vfs_unsafe_state_autogen.go
@@ -1,40 +1,3 @@
// automatically generated by stateify.
-// +build go1.12
-// +build !go1.17
-
package vfs
-
-import (
- "gvisor.dev/gvisor/pkg/state"
-)
-
-func (mt *mountTable) StateTypeName() string {
- return "pkg/sentry/vfs.mountTable"
-}
-
-func (mt *mountTable) StateFields() []string {
- return []string{
- "seed",
- "size",
- }
-}
-
-func (mt *mountTable) beforeSave() {}
-
-func (mt *mountTable) StateSave(stateSinkObject state.Sink) {
- mt.beforeSave()
- stateSinkObject.Save(0, &mt.seed)
- stateSinkObject.Save(1, &mt.size)
-}
-
-func (mt *mountTable) afterLoad() {}
-
-func (mt *mountTable) StateLoad(stateSourceObject state.Source) {
- stateSourceObject.Load(0, &mt.seed)
- stateSourceObject.Load(1, &mt.size)
-}
-
-func init() {
- state.Register((*mountTable)(nil))
-}
diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go
index 7ad6a4434..e925e2e5b 100644
--- a/pkg/sync/runtime_unsafe.go
+++ b/pkg/sync/runtime_unsafe.go
@@ -11,6 +11,8 @@
package sync
import (
+ "fmt"
+ "reflect"
"unsafe"
)
@@ -61,6 +63,57 @@ const (
TraceEvGoBlockSelect byte = 24
)
+// Rand32 returns a non-cryptographically-secure random uint32.
+func Rand32() uint32 {
+ return fastrand()
+}
+
+// Rand64 returns a non-cryptographically-secure random uint64.
+func Rand64() uint64 {
+ return uint64(fastrand())<<32 | uint64(fastrand())
+}
+
+//go:linkname fastrand runtime.fastrand
+func fastrand() uint32
+
+// RandUintptr returns a non-cryptographically-secure random uintptr.
+func RandUintptr() uintptr {
+ if unsafe.Sizeof(uintptr(0)) == 4 {
+ return uintptr(Rand32())
+ }
+ return uintptr(Rand64())
+}
+
+// MapKeyHasher returns a hash function for pointers of m's key type.
+//
+// Preconditions: m must be a map.
+func MapKeyHasher(m interface{}) func(unsafe.Pointer, uintptr) uintptr {
+ if rtyp := reflect.TypeOf(m); rtyp.Kind() != reflect.Map {
+ panic(fmt.Sprintf("sync.MapKeyHasher: m is %v, not map", rtyp))
+ }
+ mtyp := *(**maptype)(unsafe.Pointer(&m))
+ return mtyp.hasher
+}
+
+type maptype struct {
+ size uintptr
+ ptrdata uintptr
+ hash uint32
+ tflag uint8
+ align uint8
+ fieldAlign uint8
+ kind uint8
+ equal func(unsafe.Pointer, unsafe.Pointer) bool
+ gcdata *byte
+ str int32
+ ptrToThis int32
+ key unsafe.Pointer
+ elem unsafe.Pointer
+ bucket unsafe.Pointer
+ hasher func(unsafe.Pointer, uintptr) uintptr
+ // more fields
+}
+
// These functions are only used within the sync package.
//go:linkname semacquire sync.runtime_Semacquire
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index c53698a6a..b14df9e09 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -130,6 +130,10 @@ type SocketOptions struct {
// corkOptionEnabled is used to specify if data should be held until segments
// are full by the TCP transport protocol.
corkOptionEnabled uint32
+
+ // receiveOriginalDstAddress is used to specify if the original destination of
+ // the incoming packet should be returned as an ancillary message.
+ receiveOriginalDstAddress uint32
}
// InitHandler initializes the handler. This must be called before using the
@@ -302,3 +306,13 @@ func (so *SocketOptions) SetCorkOption(v bool) {
storeAtomicBool(&so.corkOptionEnabled, v)
so.handler.OnCorkOptionSet(v)
}
+
+// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) GetReceiveOriginalDstAddress() bool {
+ return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0
+}
+
+// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
+ storeAtomicBool(&so.receiveOriginalDstAddress, v)
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 23e597365..0d2dad881 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -492,6 +492,14 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+
+ // HasOriginalDestinationAddress indicates whether OriginalDstAddress is
+ // set.
+ HasOriginalDstAddress bool
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress FullAddress
}
// PacketOwner is used to get UID and GID of the packet.
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
index 94e40e30e..36ef19575 100644
--- a/pkg/tcpip/tcpip_state_autogen.go
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -28,6 +28,7 @@ func (so *SocketOptions) StateFields() []string {
"quickAckEnabled",
"delayOptionEnabled",
"corkOptionEnabled",
+ "receiveOriginalDstAddress",
}
}
@@ -51,6 +52,7 @@ func (so *SocketOptions) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(13, &so.quickAckEnabled)
stateSinkObject.Save(14, &so.delayOptionEnabled)
stateSinkObject.Save(15, &so.corkOptionEnabled)
+ stateSinkObject.Save(16, &so.receiveOriginalDstAddress)
}
func (so *SocketOptions) afterLoad() {}
@@ -72,6 +74,7 @@ func (so *SocketOptions) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(13, &so.quickAckEnabled)
stateSourceObject.Load(14, &so.delayOptionEnabled)
stateSourceObject.Load(15, &so.corkOptionEnabled)
+ stateSourceObject.Load(16, &so.receiveOriginalDstAddress)
}
func (e *Error) StateTypeName() string {
@@ -145,6 +148,8 @@ func (c *ControlMessages) StateFields() []string {
"TClass",
"HasIPPacketInfo",
"PacketInfo",
+ "HasOriginalDstAddress",
+ "OriginalDstAddress",
}
}
@@ -162,6 +167,8 @@ func (c *ControlMessages) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(7, &c.TClass)
stateSinkObject.Save(8, &c.HasIPPacketInfo)
stateSinkObject.Save(9, &c.PacketInfo)
+ stateSinkObject.Save(10, &c.HasOriginalDstAddress)
+ stateSinkObject.Save(11, &c.OriginalDstAddress)
}
func (c *ControlMessages) afterLoad() {}
@@ -177,6 +184,8 @@ func (c *ControlMessages) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(7, &c.TClass)
stateSourceObject.Load(8, &c.HasIPPacketInfo)
stateSourceObject.Load(9, &c.PacketInfo)
+ stateSourceObject.Load(10, &c.HasOriginalDstAddress)
+ stateSourceObject.Load(11, &c.OriginalDstAddress)
}
func (l *LinkPacketInfo) StateTypeName() string {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index ee1bb29f8..04596183e 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -30,10 +30,11 @@ import (
// +stateify savable
type udpPacket struct {
udpPacketEntry
- senderAddress tcpip.FullAddress
- packetInfo tcpip.IPPacketInfo
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ senderAddress tcpip.FullAddress
+ destinationAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
// tos stores either the receiveTOS or receiveTClass value.
tos uint8
}
@@ -323,6 +324,10 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
}
+ if e.ops.GetReceiveOriginalDstAddress() {
+ cm.HasOriginalDstAddress = true
+ cm.OriginalDstAddress = p.destinationAddress
+ }
return p.data.ToView(), cm, nil
}
@@ -1314,6 +1319,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
Addr: id.RemoteAddress,
Port: hdr.SourcePort(),
},
+ destinationAddress: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: header.UDP(hdr).DestinationPort(),
+ },
}
packet.data = pkt.Data
e.rcvList.PushBack(packet)
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
index dc6afdff9..451d8eff0 100644
--- a/pkg/tcpip/transport/udp/udp_state_autogen.go
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -15,6 +15,7 @@ func (u *udpPacket) StateFields() []string {
return []string{
"udpPacketEntry",
"senderAddress",
+ "destinationAddress",
"packetInfo",
"data",
"timestamp",
@@ -27,12 +28,13 @@ func (u *udpPacket) beforeSave() {}
func (u *udpPacket) StateSave(stateSinkObject state.Sink) {
u.beforeSave()
var dataValue buffer.VectorisedView = u.saveData()
- stateSinkObject.SaveValue(3, dataValue)
+ stateSinkObject.SaveValue(4, dataValue)
stateSinkObject.Save(0, &u.udpPacketEntry)
stateSinkObject.Save(1, &u.senderAddress)
- stateSinkObject.Save(2, &u.packetInfo)
- stateSinkObject.Save(4, &u.timestamp)
- stateSinkObject.Save(5, &u.tos)
+ stateSinkObject.Save(2, &u.destinationAddress)
+ stateSinkObject.Save(3, &u.packetInfo)
+ stateSinkObject.Save(5, &u.timestamp)
+ stateSinkObject.Save(6, &u.tos)
}
func (u *udpPacket) afterLoad() {}
@@ -40,10 +42,11 @@ func (u *udpPacket) afterLoad() {}
func (u *udpPacket) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &u.udpPacketEntry)
stateSourceObject.Load(1, &u.senderAddress)
- stateSourceObject.Load(2, &u.packetInfo)
- stateSourceObject.Load(4, &u.timestamp)
- stateSourceObject.Load(5, &u.tos)
- stateSourceObject.LoadValue(3, new(buffer.VectorisedView), func(y interface{}) { u.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.Load(2, &u.destinationAddress)
+ stateSourceObject.Load(3, &u.packetInfo)
+ stateSourceObject.Load(5, &u.timestamp)
+ stateSourceObject.Load(6, &u.tos)
+ stateSourceObject.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { u.loadData(y.(buffer.VectorisedView)) })
}
func (e *endpoint) StateTypeName() string {