From 92ca72ecb73d91e9def31e7f9835adf7a50b3d65 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 9 Dec 2020 14:54:43 -0800 Subject: Add support for IP_RECVORIGDSTADDR IP option. Fixes #5004 PiperOrigin-RevId: 346643745 --- pkg/sentry/socket/control/control.go | 19 ++++++- pkg/sentry/socket/hostinet/socket.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 57 +++++++++++++++++---- pkg/sentry/syscalls/linux/sys_socket.go | 3 +- pkg/sentry/syscalls/linux/vfs2/socket.go | 3 +- pkg/tcpip/checker/checker.go | 13 +++++ pkg/tcpip/socketops.go | 14 +++++ pkg/tcpip/tcpip.go | 8 +++ pkg/tcpip/transport/udp/endpoint.go | 18 +++++-- pkg/tcpip/transport/udp/udp_test.go | 87 ++++++++++++++++++++++++++++++++ 10 files changed, 206 insertions(+), 18 deletions(-) (limited to 'pkg') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..c284efde5 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -359,13 +359,26 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) ) } +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip.FullAddress, buf []byte) []byte { + p, _ := socket.ConvertAddress(family, originalDstAddress) + level := uint32(linux.SOL_IP) + optType := uint32(linux.IP_RECVORIGDSTADDR) + if family == linux.AF_INET6 { + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), p) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. // // Note that some control messages may be truncated if they do not fit under // the capacity of buf. -func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { +func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessages, buf []byte) []byte { if cmsgs.IP.HasTimestamp { buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) } @@ -387,6 +400,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) } + if cmsgs.IP.HasOriginalDstAddress { + buf = PackOriginalDstAddress(t, family, cmsgs.IP.OriginalDstAddress, buf) + } + return buf } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..a35850a8f 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -551,7 +551,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } controlBuf := make([]byte, 0, space) // PackControlMessages will append up to space bytes to controlBuf. - controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + controlBuf = control.PackControlMessages(t, s.family, controlMessages, controlBuf) sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 3cc0d4f0f..37c3faa57 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1418,6 +1418,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) return &v, nil + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil + case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { return nil, syserr.ErrInvalidArgument @@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) return &v, nil + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil + case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { return nil, syserr.ErrInvalidArgument @@ -2094,6 +2110,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2325,6 +2350,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetHeaderIncluded(v != 0) return nil + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { return syserr.ErrInvalidArgument @@ -2363,7 +2400,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2441,7 +2477,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2746,14 +2781,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + HasOriginalDstAddress: s.readCM.HasOriginalDstAddress, + OriginalDstAddress: s.readCM.OriginalDstAddress, }, } } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9cd052c3d..754a5df24 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -784,8 +784,9 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } defer cms.Release(t) + family, _, _ := s.Type() controlData := make([]byte, 0, msg.ControlLen) - controlData = control.PackControlMessages(t, cms, controlData) + controlData = control.PackControlMessages(t, family, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 7b33b3f59..c40892b9f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -787,8 +787,9 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } defer cms.Release(t) + family, _, _ := s.Type() controlData := make([]byte, 0, msg.ControlLen) - controlData = control.PackControlMessages(t, cms, controlData) + controlData = control.PackControlMessages(t, family, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index d3ae56ac6..8830c45d7 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -321,6 +321,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress +// field in ControlMessages. +func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasOriginalDstAddress { + t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) + } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { + t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index c53698a6a..b14df9e09 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -130,6 +130,10 @@ type SocketOptions struct { // corkOptionEnabled is used to specify if data should be held until segments // are full by the TCP transport protocol. corkOptionEnabled uint32 + + // receiveOriginalDstAddress is used to specify if the original destination of + // the incoming packet should be returned as an ancillary message. + receiveOriginalDstAddress uint32 } // InitHandler initializes the handler. This must be called before using the @@ -302,3 +306,13 @@ func (so *SocketOptions) SetCorkOption(v bool) { storeAtomicBool(&so.corkOptionEnabled, v) so.handler.OnCorkOptionSet(v) } + +// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) GetReceiveOriginalDstAddress() bool { + return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0 +} + +// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { + storeAtomicBool(&so.receiveOriginalDstAddress, v) +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 23e597365..0d2dad881 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -492,6 +492,14 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is + // set. + HasOriginalDstAddress bool + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress FullAddress } // PacketOwner is used to get UID and GID of the packet. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index ee1bb29f8..04596183e 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -30,10 +30,11 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry - senderAddress tcpip.FullAddress - packetInfo tcpip.IPPacketInfo - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + senderAddress tcpip.FullAddress + destinationAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 // tos stores either the receiveTOS or receiveTClass value. tos uint8 } @@ -323,6 +324,10 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } + if e.ops.GetReceiveOriginalDstAddress() { + cm.HasOriginalDstAddress = true + cm.OriginalDstAddress = p.destinationAddress + } return p.data.ToView(), cm, nil } @@ -1314,6 +1319,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB Addr: id.RemoteAddress, Port: hdr.SourcePort(), }, + destinationAddress: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: header.UDP(hdr).DestinationPort(), + }, } packet.data = pkt.Data e.rcvList.PushBack(packet) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b0b9b2773..cca865b34 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1428,6 +1428,93 @@ func TestReadIPPacketInfo(t *testing.T) { } } +func TestReadRecvOriginalDstAddr(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedOriginalDstAddr tcpip.FullAddress + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) + } + } + + c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) + + testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() -- cgit v1.2.3