diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/abi/linux/socket.go | 13 | ||||
-rw-r--r-- | pkg/binary/binary.go | 10 | ||||
-rw-r--r-- | pkg/sentry/fs/fsutil/host_file_mapper.go | 17 | ||||
-rw-r--r-- | pkg/sentry/fs/tty/slave.go | 2 | ||||
-rw-r--r-- | pkg/sentry/fsimpl/gofer/regular_file.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/control/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/control/control.go | 69 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 11 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/extensions.go | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/message.go | 15 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 37 | ||||
-rw-r--r-- | pkg/sentry/strace/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/strace/socket.go | 7 | ||||
-rw-r--r-- | pkg/sleep/commit_noasm.go | 13 | ||||
-rw-r--r-- | pkg/sleep/sleep_unsafe.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 85 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 26 |
20 files changed, 304 insertions, 77 deletions
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 766ee4014..4a14ef691 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -411,6 +411,15 @@ type ControlMessageCredentials struct { GID uint32 } +// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message. +// +// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. +type ControlMessageIPPacketInfo struct { + NIC int32 + LocalAddr InetAddr + DestinationAddr InetAddr +} + // SizeOfControlMessageCredentials is the binary size of a // ControlMessageCredentials struct. var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{})) @@ -431,6 +440,10 @@ const SizeOfControlMessageTOS = 1 // SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message. const SizeOfControlMessageTClass = 4 +// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO +// control message. +const SizeOfControlMessageIPPacketInfo = 12 + // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 diff --git a/pkg/binary/binary.go b/pkg/binary/binary.go index 631785f7b..25065aef9 100644 --- a/pkg/binary/binary.go +++ b/pkg/binary/binary.go @@ -254,3 +254,13 @@ func WriteUint64(w io.Writer, order binary.ByteOrder, num uint64) error { _, err := w.Write(buf) return err } + +// AlignUp rounds a length up to an alignment. align must be a power of 2. +func AlignUp(length int, align uint) int { + return (length + int(align) - 1) & ^(int(align) - 1) +} + +// AlignDown rounds a length down to an alignment. align must be a power of 2. +func AlignDown(length int, align uint) int { + return length & ^(int(align) - 1) +} diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 67278aa86..e82afd112 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -65,13 +65,18 @@ type mapping struct { writable bool } -// NewHostFileMapper returns a HostFileMapper with no references or cached -// mappings. +// Init must be called on zero-value HostFileMappers before first use. +func (f *HostFileMapper) Init() { + f.refs = make(map[uint64]int32) + f.mappings = make(map[uint64]mapping) +} + +// NewHostFileMapper returns an initialized HostFileMapper allocated on the +// heap with no references or cached mappings. func NewHostFileMapper() *HostFileMapper { - return &HostFileMapper{ - refs: make(map[uint64]int32), - mappings: make(map[uint64]mapping), - } + f := &HostFileMapper{} + f.Init() + return f } // IncRefOn increments the reference count on all offsets in mr. diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index db55cdc48..6a2dbc576 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -73,7 +73,7 @@ func (si *slaveInodeOperations) Release(ctx context.Context) { } // Truncate implements fs.InodeOperations.Truncate. -func (slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { +func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { return nil } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 8e11e06b3..54c1031a7 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -571,6 +571,8 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt default: panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop)) } + // After this point, d may be used as a memmap.Mappable. + d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init) return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts) } @@ -799,6 +801,9 @@ type dentryPlatformFile struct { // If this dentry represents a regular file, and handle.fd >= 0, // hostFileMapper caches mappings of handle.fd. hostFileMapper fsutil.HostFileMapper + + // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. + hostFileMapperInitOnce sync.Once } // IncRef implements platform.File.IncRef. diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 79e16d6e8..4d42d29cb 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", "//pkg/syserror", + "//pkg/tcpip", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 00265f15b..4667373d2 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" ) @@ -189,7 +190,7 @@ func putUint32(buf []byte, n uint32) []byte { // putCmsg writes a control message header and as much data as will fit into // the unused capacity of a buffer. func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) { - space := AlignDown(cap(buf)-len(buf), 4) + space := binary.AlignDown(cap(buf)-len(buf), 4) // We can't write to space that doesn't exist, so if we are going to align // the available space, we must align down. @@ -282,19 +283,9 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c) } -// AlignUp rounds a length up to an alignment. align must be a power of 2. -func AlignUp(length int, align uint) int { - return (length + int(align) - 1) & ^(int(align) - 1) -} - -// AlignDown rounds a down to an alignment. align must be a power of 2. -func AlignDown(length int, align uint) int { - return length & ^(int(align) - 1) -} - // alignSlice extends a slice's length (up to the capacity) to align it. func alignSlice(buf []byte, align uint) []byte { - aligned := AlignUp(len(buf), align) + aligned := binary.AlignUp(len(buf), align) if aligned > cap(buf) { // Linux allows unaligned data if there isn't room for alignment. // Since there isn't room for alignment, there isn't room for any @@ -348,6 +339,22 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { ) } +// PackIPPacketInfo packs an IP_PKTINFO socket control message. +func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + + return putCmsgStruct( + buf, + linux.SOL_IP, + linux.IP_PKTINFO, + t.Arch().Width(), + p, + ) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. @@ -372,12 +379,16 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackTClass(t, cmsgs.IP.TClass, buf) } + if cmsgs.IP.HasIPPacketInfo { + buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + } + return buf } // cmsgSpace is equivalent to CMSG_SPACE in Linux. func cmsgSpace(t *kernel.Task, dataLen int) int { - return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width()) + return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width()) } // CmsgsSpace returns the number of bytes needed to fit the control messages @@ -404,6 +415,16 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { return space } +// NewIPPacketInfo returns the IPPacketInfo struct. +func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { + var p tcpip.IPPacketInfo + p.NIC = tcpip.NICID(packetInfo.NIC) + copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) + copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + + return p +} + // Parse parses a raw socket control message into portable objects. func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( @@ -437,7 +458,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con case linux.SOL_SOCKET: switch h.Type { case linux.SCM_RIGHTS: - rightsSize := AlignDown(length, linux.SizeOfControlMessageRight) + rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) numRights := rightsSize / linux.SizeOfControlMessageRight if len(fds)+numRights > linux.SCM_MAX_FD { @@ -448,7 +469,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) } - i += AlignUp(length, width) + i += binary.AlignUp(length, width) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -462,7 +483,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += AlignUp(length, width) + i += binary.AlignUp(length, width) default: // Unknown message type. @@ -476,7 +497,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con } cmsgs.IP.HasTOS = true binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) - i += AlignUp(length, width) + i += binary.AlignUp(length, width) + + case linux.IP_PKTINFO: + if length < linux.SizeOfControlMessageIPPacketInfo { + return socket.ControlMessages{}, syserror.EINVAL + } + + cmsgs.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + + cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + i += binary.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL @@ -489,7 +522,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con } cmsgs.IP.HasTClass = true binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) - i += AlignUp(length, width) + i += binary.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index de76388ac..22f78d2e2 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -289,7 +289,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: optlen = sizeofInt32 } case linux.SOL_IPV6: @@ -336,6 +336,8 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ switch name { case linux.IP_TOS, linux.IP_RECVTOS: optlen = sizeofInt32 + case linux.IP_PKTINFO: + optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { @@ -473,7 +475,14 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags case syscall.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + + case syscall.IP_PKTINFO: + controlMessages.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) } + case syscall.SOL_IPV6: switch unixCmsg.Header.Type { case syscall.IPV6_TCLASS: diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 22fd0ebe7..b4b244abf 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -72,7 +72,7 @@ func marshalEntryMatch(name string, data []byte) []byte { nflog("marshaling matcher %q", name) // We have to pad this struct size to a multiple of 8 bytes. - size := alignUp(linux.SizeOfXTEntryMatch+len(data), 8) + size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) matcher := linux.KernelXTEntryMatch{ XTEntryMatch: linux.XTEntryMatch{ MatchSize: uint16(size), @@ -93,8 +93,3 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, } return matchMaker.unmarshal(buf, filter) } - -// alignUp rounds a length up to an alignment. align must be a power of 2. -func alignUp(length int, align uint) int { - return (length + int(align) - 1) & ^(int(align) - 1) -} diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index 4ea252ccb..0899c61d1 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -23,18 +23,11 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// alignUp rounds a length up to an alignment. -// -// Preconditions: align is a power of two. -func alignUp(length int, align uint) int { - return (length + int(align) - 1) &^ (int(align) - 1) -} - // alignPad returns the length of padding required for alignment. // // Preconditions: align is a power of two. func alignPad(length int, align uint) int { - return alignUp(length, align) - length + return binary.AlignUp(length, align) - length } // Message contains a complete serialized netlink message. @@ -138,7 +131,7 @@ func (m *Message) Finalize() []byte { // Align the message. Note that the message length in the header (set // above) is the useful length of the message, not the total aligned // length. See net/netlink/af_netlink.c:__nlmsg_put. - aligned := alignUp(len(m.buf), linux.NLMSG_ALIGNTO) + aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) m.putZeros(aligned - len(m.buf)) return m.buf } @@ -173,7 +166,7 @@ func (m *Message) PutAttr(atype uint16, v interface{}) { m.Put(v) // Align the attribute. - aligned := alignUp(l, linux.NLA_ALIGNTO) + aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } @@ -190,7 +183,7 @@ func (m *Message) PutAttrString(atype uint16, s string) { m.putZeros(1) // Align the attribute. - aligned := alignUp(l, linux.NLA_ALIGNTO) + aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index ed2fbcceb..9757fbfba 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1414,6 +1414,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return o, nil + case linux.IP_PKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + var o int32 + if v { + o = 1 + } + return o, nil + default: emitUnimplementedEventIP(t, name) } @@ -1762,6 +1777,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) linux.IPV6_IPSEC_POLICY, linux.IPV6_JOIN_ANYCAST, linux.IPV6_LEAVE_ANYCAST, + // TODO(b/148887420): Add support for IPV6_PKTINFO. linux.IPV6_PKTINFO, linux.IPV6_ROUTER_ALERT, linux.IPV6_XFRM_POLICY, @@ -1949,6 +1965,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + case linux.IP_PKTINFO: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1964,7 +1990,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_PKTINFO, linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, @@ -2395,10 +2420,12 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe func (s *SocketOperations) controlMessages() socket.ControlMessages { return socket.ControlMessages{ IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, }, } } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 762a946fe..2f39a6f2b 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -30,7 +30,6 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/kernel", - "//pkg/sentry/socket/control", "//pkg/sentry/socket/netlink", "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f7ff4573e..51e6d81b2 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" @@ -220,13 +219,13 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += control.AlignUp(length, width) + i += binary.AlignUp(length, width) continue } switch h.Type { case linux.SCM_RIGHTS: - rightsSize := control.AlignDown(length, linux.SizeOfControlMessageRight) + rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) numRights := rightsSize / linux.SizeOfControlMessageRight fds := make(linux.ControlMessageRights, numRights) @@ -295,7 +294,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) default: panic("unreachable") } - i += control.AlignUp(length, width) + i += binary.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go index 3af447fb9..f59061f37 100644 --- a/pkg/sleep/commit_noasm.go +++ b/pkg/sleep/commit_noasm.go @@ -28,15 +28,6 @@ import "sync/atomic" // It is written in assembly because it is called from g0, so it doesn't have // a race context. func commitSleep(g uintptr, waitingG *uintptr) bool { - for { - // Check if the wait was aborted. - if atomic.LoadUintptr(waitingG) == 0 { - return false - } - - // Try to store the G so that wakers know who to wake. - if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) { - return true - } - } + // Try to store the G so that wakers know who to wake. + return atomic.CompareAndSwapUintptr(waitingG, preparingG, g) } diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index acbf0229b..65bfcf778 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -299,20 +299,17 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) { } } - for { - // Nothing to do if there isn't a G waiting. - g := atomic.LoadUintptr(&s.waitingG) - if g == 0 { - return - } + // Nothing to do if there isn't a G waiting. + if atomic.LoadUintptr(&s.waitingG) == 0 { + return + } - // Signal to the sleeper that a waker has been asserted. - if atomic.CompareAndSwapUintptr(&s.waitingG, g, 0) { - if g != preparingG { - // We managed to get a G. Wake it up. - goready(g, 0) - } - } + // Signal to the sleeper that a waker has been asserted. + switch g := atomic.SwapUintptr(&s.waitingG, 0); g { + case 0, preparingG: + default: + // We managed to get a G. Wake it up. + goready(g, 0) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 78d451cca..ca3a7a07e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1215,6 +1215,11 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Name returns the name of n. +func (n *NIC) Name() string { + return n.name +} + // Stack returns the instance of the Stack that owns this NIC. func (n *NIC) Stack() *Stack { return n.stack diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index b793f1d74..6eac16e16 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -890,6 +890,15 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp return tcpip.ErrDuplicateNICID } + // Make sure name is unique, unless unnamed. + if opts.Name != "" { + for _, n := range s.nics { + if n.Name() == opts.Name { + return tcpip.ErrDuplicateNICID + } + } + } + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 24133e6f2..7ba604442 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1792,6 +1792,91 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { verifyAddresses(t, expectedAddresses, gotAddresses) } +func TestCreateNICWithOptions(t *testing.T) { + type callArgsAndExpect struct { + nicID tcpip.NICID + opts stack.NICOptions + err *tcpip.Error + } + + tests := []struct { + desc string + calls []callArgsAndExpect + }{ + { + desc: "DuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth1"}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth2"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "DuplicateName", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "lo"}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{Name: "lo"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "Unnamed", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{}, + err: nil, + }, + }, + }, + { + desc: "UnnamedDuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + s := stack.New(stack.Options{}) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + for _, call := range test.calls { + if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { + t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) + } + } + }) + } +} + func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0e944712f..9ca39ce40 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -328,6 +328,12 @@ type ControlMessages struct { // Tclass is the IPv6 traffic class of the associated packet. TClass int32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo IPPacketInfo } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -503,6 +509,11 @@ const ( // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. V6OnlyOption + + // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify + // if more inforamtion is provided with incoming packets such + // as interface index and address. + ReceiveIPPacketInfoOption ) // SockOptInt represents socket options which values have the int type. @@ -685,6 +696,20 @@ type IPv4TOSOption uint8 // for all subsequent outgoing IPv6 packets from the endpoint. type IPv6TrafficClassOption uint8 +// IPPacketInfo is the message struture for IP_PKTINFO. +// +// +stateify savable +type IPPacketInfo struct { + // NIC is the ID of the NIC to be used. + NIC NICID + + // LocalAddr is the local address. + LocalAddr Address + + // DestinationAddr is the destination address. + DestinationAddr Address +} + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c9cbed8f4..3fe91cac2 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -29,6 +29,7 @@ import ( type udpPacket struct { udpPacketEntry senderAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 tos uint8 @@ -118,6 +119,9 @@ type endpoint struct { // as ancillary data to ControlMessages on Read. receiveTOS bool + // receiveIPPacketInfo determines if the packet info is returned by Read. + receiveIPPacketInfo bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -254,11 +258,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess } e.mu.RLock() receiveTOS := e.receiveTOS + receiveIPPacketInfo := e.receiveIPPacketInfo e.mu.RUnlock() if receiveTOS { cm.HasTOS = true cm.TOS = p.tos } + + if receiveIPPacketInfo { + cm.HasIPPacketInfo = true + cm.PacketInfo = p.packetInfo + } return p.data.ToView(), cm, nil } @@ -495,6 +505,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { } e.v6only = v + return nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + return nil } return nil @@ -703,6 +720,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil } return false, tcpip.ErrUnknownProtocolOption @@ -1247,6 +1270,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk switch r.NetProto { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + packet.packetInfo.LocalAddr = r.LocalAddress + packet.packetInfo.DestinationAddr = r.RemoteAddress + packet.packetInfo.NIC = r.NICID() } packet.timestamp = e.stack.NowNanoseconds() |