diff options
62 files changed, 900 insertions, 359 deletions
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index a6cc0617e..de9357389 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -17,7 +17,6 @@ package p9 import ( "fmt" "io" - "runtime" "sync/atomic" "syscall" @@ -45,15 +44,10 @@ func (c *Client) Attach(name string) (File, error) { // newFile returns a new client file. func (c *Client) newFile(fid FID) *clientFile { - cf := &clientFile{ + return &clientFile{ client: c, fid: fid, } - - // Make sure the file is closed. - runtime.SetFinalizer(cf, (*clientFile).Close) - - return cf } // clientFile is provided to clients. @@ -192,7 +186,6 @@ func (c *clientFile) Remove() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return syscall.EBADF } - runtime.SetFinalizer(c, nil) // Send the remove message. if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil { @@ -214,7 +207,6 @@ func (c *clientFile) Close() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return syscall.EBADF } - runtime.SetFinalizer(c, nil) // Send the close message. if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil { diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index f4ddfa406..2d43dff1d 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -344,6 +344,10 @@ func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error // Truncate calls i.InodeOperations.Truncate with i as the Inode. func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error { + if IsDir(i.StableAttr) { + return syserror.EISDIR + } + if i.overlay != nil { return overlayTruncate(ctx, i.overlay, d, size) } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 5a388dad1..a09147080 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -462,7 +462,9 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri inode.DecRef() return nil, err } - return NewDirent(ctx, newOverlayInode(ctx, entry, inode.MountSource), name), nil + // Use the parent's MountSource, since that corresponds to the overlay, + // and not the upper filesystem. + return NewDirent(ctx, newOverlayInode(ctx, entry, parent.Inode.MountSource), name), nil } func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint { diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 19b7557d5..bc56be696 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -32,6 +32,7 @@ import ( // +stateify savable type masterInodeOperations struct { fsutil.SimpleFileInode + fsutil.InodeNoopTruncate // d is the containing dir. d *dirInodeOperations diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index 944c4ada1..4cbea0367 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -31,6 +31,7 @@ import ( // +stateify savable type slaveInodeOperations struct { fsutil.SimpleFileInode + fsutil.InodeNoopTruncate // d is the containing dir. d *dirInodeOperations diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index b9a8e3e21..167c2b60b 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -169,10 +169,11 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint if dirPath { return syserror.ENOTDIR } - if flags&linux.O_TRUNC != 0 { - if err := d.Inode.Truncate(t, d, 0); err != nil { - return err - } + } + + if flags&linux.O_TRUNC != 0 { + if err := d.Inode.Truncate(t, d, 0); err != nil { + return err } } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index b5a72ce63..ab1001f16 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -447,16 +447,13 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - // Read the length if present. Reject negative values. + // Read the length. Reject negative values. optLen := int32(0) - if optLenAddr != 0 { - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { - return 0, nil, err - } - - if optLen < 0 { - return 0, nil, syserror.EINVAL - } + if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + return 0, nil, err + } + if optLen < 0 { + return 0, nil, syserror.EINVAL } // Call syscall implementation then copy both value and value len out. @@ -465,11 +462,9 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - if optLenAddr != 0 { - vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { - return 0, nil, err - } + vLen := int32(binary.Size(v)) + if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + return 0, nil, err } if v != nil { diff --git a/pkg/sentry/usermem/bytes_io.go b/pkg/sentry/usermem/bytes_io.go index 8d88396ba..7898851b3 100644 --- a/pkg/sentry/usermem/bytes_io.go +++ b/pkg/sentry/usermem/bytes_io.go @@ -102,19 +102,34 @@ func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) { } func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) { - blocks := make([]safemem.Block, 0, ars.NumRanges()) - for !ars.IsEmpty() { - ar := ars.Head() - n, err := b.rangeCheck(ar.Start, int(ar.Length())) - if n != 0 { - blocks = append(blocks, safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start):int(ar.Start)+n])) + switch ars.NumRanges() { + case 0: + return safemem.BlockSeq{}, nil + case 1: + block, err := b.blockFromAddrRange(ars.Head()) + return safemem.BlockSeqOf(block), err + default: + blocks := make([]safemem.Block, 0, ars.NumRanges()) + for !ars.IsEmpty() { + block, err := b.blockFromAddrRange(ars.Head()) + if block.Len() != 0 { + blocks = append(blocks, block) + } + if err != nil { + return safemem.BlockSeqFromSlice(blocks), err + } + ars = ars.Tail() } - if err != nil { - return safemem.BlockSeqFromSlice(blocks), err - } - ars = ars.Tail() + return safemem.BlockSeqFromSlice(blocks), nil + } +} + +func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) { + n, err := b.rangeCheck(ar.Start, int(ar.Length())) + if n == 0 { + return safemem.Block{}, err } - return safemem.BlockSeqFromSlice(blocks), nil + return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err } // BytesIOSequence returns an IOSequence representing the given byte slice. diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 48a2a2713..ba21f4eca 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -77,3 +77,9 @@ func (p *Prependable) Prepend(size int) []byte { p.usedIdx -= size return p.View()[:size:size] } + +// DeepCopy copies p and the bytes backing it. +func (p Prependable) DeepCopy() Prependable { + p.buf = append(View(nil), p.buf...) + return p +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 22eefb564..9fe8e9f9d 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -25,10 +25,9 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Header buffer.View - Payload buffer.View - Proto tcpip.NetworkProtocolNumber - GSO *stack.GSO + Pkt tcpip.PacketBuffer + Proto tcpip.NetworkProtocolNumber + GSO *stack.GSO } // Endpoint is link layer endpoint that stores outbound packets in a channel @@ -118,12 +117,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: payload.ToView(), - GSO: gso, + Pkt: pkt, + Proto: protocol, + GSO: gso, } select { @@ -139,15 +137,16 @@ func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, hdrs []stack.Pac payloadView := payload.ToView() n := 0 packetLoop: - for i := range hdrs { - hdr := &hdrs[i].Hdr - off := hdrs[i].Off - size := hdrs[i].Size + for _, hdr := range hdrs { + off := hdr.Off + size := hdr.Size p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: buffer.NewViewFromBytes(payloadView[off : off+size]), - GSO: gso, + Pkt: tcpip.PacketBuffer{ + Header: hdr.Hdr, + Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), + }, + Proto: protocol, + GSO: gso, } select { @@ -162,12 +161,11 @@ packetLoop: } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Header: packet.ToView(), - Proto: 0, - Payload: buffer.View{}, - GSO: nil, + Pkt: tcpip.PacketBuffer{Data: vv}, + Proto: 0, + GSO: nil, } select { diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index edef7db26..98109c5dc 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -386,10 +386,11 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -408,13 +409,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen vnetHdr := virtioNetHdr{} vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) if gso != nil { - vnetHdr.hdrLen = uint16(hdr.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen vnetHdr.csumOffset = gso.CsumOffset } - if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS { + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { switch gso.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 @@ -427,14 +428,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } } - return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView()) + return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) } - if payload.Size() == 0 { - return rawfile.NonBlockingWrite(e.fds[0], hdr.View()) + if pkt.Data.Size() == 0 { + return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View()) } - return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil) + return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil) } // WritePackets writes outbound packets to the file descriptor. If it is not @@ -555,8 +556,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { - return rawfile.NonBlockingWrite(e.fds[0], packet.ToView()) +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) } // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 7e08e033b..2066987eb 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -168,7 +168,10 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -258,7 +261,10 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.VectorisedView{}, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index bc5d8a2f3..563a67188 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,16 +76,16 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - - // Because we're immediately turning around and writing the packet back to the - // rx path, we intentionally don't preserve the remote and local link - // addresses from the stack.Route we're passed. +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + + // Because we're immediately turning around and writing the packet back + // to the rx path, we intentionally don't preserve the remote and local + // link addresses from the stack.Route we're passed. e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) return nil @@ -97,17 +97,17 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, hdrs []stack.Packe } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Reject the packet if it's shorter than an ethernet header. - if packet.Size() < header.EthernetMinimumSize { + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - // There should be an ethernet header at the beginning of packet. - linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize]) - packet.TrimFront(len(linkHeader)) + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ - Data: packet, + Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 9a8e8ebfe..55ed2a28e 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -98,15 +98,15 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs [ // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, gso, hdr, payload, protocol) + return endpoint.WritePacket(r, gso, protocol, pkt) } return tcpip.ErrNoRoute } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (m *InjectableEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { // WriteRawPacket doesn't get a route or network address, so there's // nowhere to write this. return tcpip.ErrNoRoute diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 9cd300af8..63b249837 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,8 +50,10 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -68,8 +70,10 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewView(0).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 2bace5298..88947a03a 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,9 +185,10 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { // Add the ethernet header here. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -199,10 +200,10 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa } eth.Encode(ethHdr) - v := payload.ToView() + v := pkt.Data.ToView() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(hdr.View(), v) + ok := e.tx.transmit(pkt.Header.View(), v) e.mu.Unlock() if !ok { @@ -218,8 +219,8 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.Packe } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { - v := packet.ToView() +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + v := vv.ToView() // Transmit the packet. e.mu.Lock() ok := e.tx.transmit(v, buffer.View{}) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 199406886..89603c48f 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -273,7 +273,10 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -342,7 +345,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -396,7 +401,10 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -411,7 +419,10 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -436,7 +447,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -456,7 +470,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -471,7 +488,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -494,7 +514,10 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -510,7 +533,10 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -535,7 +561,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -548,7 +577,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: uu, + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -556,7 +588,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index d71a03cd2..122680e10 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -193,19 +193,19 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, hdr.View(), gso) + logPacket("send", protocol, pkt.Header.View(), gso) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - hdrBuf := hdr.View() - length := len(hdrBuf) + payload.Size() + hdrBuf := pkt.Header.View() + length := len(hdrBuf) + pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil { panic(err) } if len(hdrBuf) > length { @@ -215,7 +215,7 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload bu panic(err) } length -= len(hdrBuf) - logVectorisedView(payload, length, buf) + logVectorisedView(pkt.Data, length, buf) if _, err := e.file.Write(buf.Bytes()); err != nil { panic(err) } @@ -225,9 +225,9 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload bu // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - e.dumpPacket(gso, hdr, payload, protocol) - return e.lower.WritePacket(r, gso, hdr, payload, protocol) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + e.dumpPacket(gso, protocol, pkt) + return e.lower.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by @@ -236,32 +236,35 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { view := payload.ToView() for _, d := range hdrs { - e.dumpPacket(gso, d.Hdr, buffer.NewVectorisedView(d.Size, []buffer.View{view[d.Off:][:d.Size]}), protocol) + e.dumpPacket(gso, protocol, tcpip.PacketBuffer{ + Header: d.Hdr, + Data: view[d.Off:][:d.Size].ToVectorisedView(), + }) } return e.lower.WritePackets(r, gso, hdrs, payload, protocol) } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - length := packet.Size() + length := vv.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(packet.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { panic(err) } - logVectorisedView(packet, length, buf) + logVectorisedView(vv, length, buf) if _, err := e.file.Write(buf.Bytes()); err != nil { panic(err) } } - return e.lower.WriteRawPacket(packet) + return e.lower.WriteRawPacket(vv) } func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index b440970e0..12e7c1932 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -99,12 +99,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, gso, hdr, payload, protocol) + err := e.lower.WritePacket(r, gso, protocol, pkt) e.writeGate.Leave() return err } @@ -123,12 +123,12 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WriteRawPacket(packet) + err := e.lower.WriteRawPacket(vv) e.writeGate.Leave() return err } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index df2e70e54..0fc0c2ebe 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } @@ -76,7 +76,7 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stac return len(hdrs), nil } -func (e *countedEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { e.writeCount++ return nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 0ee509ebe..30aec9ba7 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -88,7 +88,7 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketDescript return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -106,14 +106,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { return // we have no useful answer, ignore the request } hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - pkt := header.ARP(hdr.Prepend(header.ARPSize)) - pkt.SetIPv4OverEthernet() - pkt.SetOp(header.ARPReply) - copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) - copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + packet := header.ARP(hdr.Prepend(header.ARPSize)) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -165,7 +167,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 47098bfdc..8e6048a21 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -110,13 +110,13 @@ func TestDirectRequest(t *testing.T) { for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + pi := <-c.linkEP.C + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pkt.Header) + rep := header.ARP(pi.Pkt.Header.View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index fe499d47e..1de188738 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -150,24 +150,24 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(hdr.View()) + h := header.IPv4(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(hdr.View()) + h := header.IPv6(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, payload, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data, srcAddr, dstAddr) return nil } @@ -239,7 +239,10 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -477,7 +480,10 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index ce771631c..32bf39e43 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -99,7 +99,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: vv, + TransportHeader: buffer.View(pkt), + }); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index ac16c8add..040329a74 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -117,13 +117,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in hdr but does not assume -// that only the IP header is in hdr. It assumes that the input packet's stated -// length matches the length of the hdr+payload. mtu includes the IP header and -// options. This does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { +// write. It assumes that the IP header is entirely in pkt.Header but does not +// assume that only the IP header is in pkt.Header. It assumes that the input +// packet's stated length matches the length of the header+payload. mtu +// includes the IP header and options. This does not support the DontFragment +// IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(hdr.View()) + ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -137,62 +138,77 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := hdr.AvailableLength() + originalAvailableLength := pkt.Header.AvailableLength() for i := 0; i < n; i++ { // Where possible, the first fragment that is sent has the same - // hdr.UsedLength() as the input packet. The link-layer endpoint may depends - // on this for looking at, eg, L4 headers. + // pkt.Header.UsedLength() as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. h := ip if i > 0 { - hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) + pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) copy(h, ip[:ip.HeaderLength()]) } if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) offset += uint16(innerMTU) if i > 0 { - newPayload := payload.Clone([]buffer.View{}) + newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayload.Size()) + pkt.Data.TrimFront(newPayload.Size()) continue } - // Special handling for the first fragment because it comes from the hdr. - if outerMTU >= hdr.UsedLength() { - // This fragment can fit all of hdr and possibly some of payload, too. - newPayload := payload.Clone([]buffer.View{}) - newPayloadLength := outerMTU - hdr.UsedLength() + // Special handling for the first fragment because it comes + // from the header. + if outerMTU >= pkt.Header.UsedLength() { + // This fragment can fit all of pkt.Header and possibly + // some of pkt.Data, too. + newPayload := pkt.Data.Clone(nil) + newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayloadLength) + pkt.Data.TrimFront(newPayloadLength) } else { - // The fragment is too small to fit all of hdr. - startOfHdr := hdr - startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) + // The fragment is too small to fit all of pkt.Header. + startOfHdr := pkt.Header + startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of hdr into the payload that remains to be sent. - restOfHdr := hdr.View()[outerMTU:] + // Add the unused bytes of pkt.Header into the pkt.Data + // that remains to be sent. + restOfHdr := pkt.Header.View()[outerMTU:] tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(payload) - payload = tmp + tmp.Append(pkt.Data) + pkt.Data = tmp } } return nil @@ -222,17 +238,17 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - ip := e.addIPHeader(r, &hdr, payload.Size(), params) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() e.HandlePacket(&loopedR, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), NetworkHeader: buffer.View(ip), }) @@ -241,10 +257,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen if loop&stack.PacketOut == 0 { return nil } - if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } - if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { return err } r.Stats().IP.PacketsSent.Increment() @@ -270,16 +286,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(payload.First()) - if !ip.IsValid(payload.Size()) { + ip := header.IPv4(pkt.Data.First()) + if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } // Always set the total length. - ip.SetTotalLength(uint16(payload.Size())) + ip.SetTotalLength(uint16(pkt.Data.Size())) // Set the source address when zero. if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { @@ -293,7 +309,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // Set the packet ID when zero. if ip.ID() == 0 { id := uint32(0) - if payload.Size() > header.IPv4MaximumHeaderSize+8 { + if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) @@ -306,18 +322,18 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, tcpip.PacketBuffer{ - Data: payload, - NetworkHeader: buffer.View(ip), - }) + e.HandlePacket(r, pkt.Clone()) } if loop&stack.PacketOut == 0 { return nil } - hdr := buffer.NewPrependableFromView(payload.ToView()) r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + + ip = ip[:ip.HeaderLength()] + pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) + pkt.Data.TrimFront(int(ip.HeaderLength())) + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 01dfb5f20..e900f1b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -113,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { +func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) + source = append(source, sourcePacketInfo.Data.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -132,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe for i, packet := range packets { // Confirm that the packet is valid. allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) + allBytes.Append(packet.Data) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -173,7 +173,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe type errorChannel struct { *channel.Endpoint - Ch chan packetInfo + Ch chan tcpip.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -183,17 +183,11 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan packetInfo, size), + Ch: make(chan tcpip.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView -} - // Drain removes all outbound packets from the channel and counts them. func (e *errorChannel) Drain() int { c := 0 @@ -208,14 +202,9 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { select { - case e.Ch <- p: + case e.Ch <- pkt: default: } @@ -292,18 +281,21 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ + source := tcpip.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Payload: payload.Clone([]buffer.View{}), + Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) if err != nil { t.Errorf("err got %v, want %v", err, nil) } - var results []packetInfo + var results []tcpip.PacketBuffer L: for { select { @@ -345,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6629951c6..1c3410618 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -226,7 +226,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return } @@ -291,7 +293,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, hdr, pkt.Data, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: pkt.Data, + }); err != nil { sent.Dropped.Increment() return } @@ -417,7 +422,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 6037a1ef8..335f634d5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -55,7 +55,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { return nil } @@ -276,22 +276,22 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pkt := <-args.src.C + pi := <-args.src.C { - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) + views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} + size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ Data: vv, }) } - if pkt.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pkt.Proto) + if pi.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pi.Proto) return } - ipv6 := header.IPv6(pkt.Header) + ipv6 := header.IPv6(pi.Pkt.Header.View()) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4cee848a1..8d1578ed9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -112,17 +112,17 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - ip := e.addIPHeader(r, &hdr, payload.Size(), params) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() e.HandlePacket(&loopedR, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), NetworkHeader: buffer.View(ip), }) @@ -133,7 +133,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -158,7 +158,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // TODO(b/119580726): Support IPv6 header-included packets. return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go index 10b04239d..695f7b188 100644 --- a/pkg/tcpip/packet_buffer.go +++ b/pkg/tcpip/packet_buffer.go @@ -31,12 +31,19 @@ type PacketBuffer struct { // or otherwise modified. Data buffer.VectorisedView + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. + Header buffer.Prependable + + // These fields are used by both inbound and outbound packets. They + // typically overlap with the Data and Header fields. + // // The bytes backing these views are immutable. Each field may be nil // if either it has not been set yet or no such header exists (e.g. // packets sent via loopback may not have a link header). // - // These fields may be Views into other Views. SR dosen't support this, - // so deep copies are necessary in some cases. + // These fields may be Views into other slices (either Data or Header). + // SR dosen't support this, so deep copies are necessary in some cases. LinkHeader buffer.View NetworkHeader buffer.View TransportHeader buffer.View @@ -44,11 +51,9 @@ type PacketBuffer struct { // Clone makes a copy of pk. It clones the Data field, which creates a new // VectorisedView but does not deep copy the underlying bytes. +// +// Clone also does not deep copy any of its other fields. func (pk PacketBuffer) Clone() PacketBuffer { - return PacketBuffer{ - Data: pk.Data.Clone(nil), - LinkHeader: pk.LinkHeader, - NetworkHeader: pk.NetworkHeader, - TransportHeader: pk.TransportHeader, - } + pk.Data = pk.Data.Clone(nil) + return pk } diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go index 04c4cf136..ad3cc24fa 100644 --- a/pkg/tcpip/packet_buffer_state.go +++ b/pkg/tcpip/packet_buffer_state.go @@ -20,6 +20,7 @@ import "gvisor.dev/gvisor/pkg/tcpip/buffer" func (pk *PacketBuffer) beforeSave() { // Non-Data fields may be slices of the Data field. This causes // problems for SR, so during save we make each header independent. + pk.Header = pk.Header.DeepCopy() pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 8357dca77..cfdd0496e 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -421,7 +421,9 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}); err != nil { + if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return false, err } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 494244368..5b901f947 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -344,7 +344,7 @@ func TestDADResolve(t *testing.T) { } // Check NDP packet. - checker.IPv6(t, p.Header.ToVectorisedView().First(), + checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1))) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 9ed9e1e7c..3f8d7312c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -812,15 +812,15 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) pkt.Data.RemoveFirst() // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, pkt.Data, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + pkt.Data.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } } return diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index c0026f5a3..7fd4e4a65 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -232,8 +232,9 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error + // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have + // already been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. @@ -241,7 +242,7 @@ type NetworkEndpoint interface { // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -361,13 +362,15 @@ type LinkEndpoint interface { // link endpoint. LinkAddress() tcpip.LinkAddress - // WritePacket writes a packet with the given protocol through the given - // route. + // WritePacket writes a packet with the given protocol through the + // given route. It sets pkt.LinkHeader if a link layer header exists. + // pkt.NetworkHeader and pkt.TransportHeader must have already been + // set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the // given route. @@ -379,7 +382,7 @@ type LinkEndpoint interface { // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. - WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error + WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1a0a51b57..617f5a57c 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -154,17 +154,17 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.Loop) + err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } return err } @@ -209,17 +209,17 @@ func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.V // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.Loop); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 2f8d8e822..0e88643a4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1366,10 +1366,10 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t } fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) fakeHeader.Encode(ðFields) - ethHeader := buffer.View(fakeHeader).ToVectorisedView() - ethHeader.Append(payload) + vv := buffer.View(fakeHeader).ToVectorisedView() + vv.Append(payload) - if err := nic.linkEP.WriteRawPacket(ethHeader); err != nil { + if err := nic.linkEP.WriteRawPacket(vv); err != nil { return err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index bf1d6974c..f979e2b1a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -122,31 +122,30 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := hdr.Prepend(fakeNetHeaderLen) + b := pkt.Header.Prepend(fakeNetHeaderLen) b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) f.HandlePacket(r, tcpip.PacketBuffer{ - Data: vv, + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } if loop&stack.PacketOut == 0 { return nil } - return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -154,7 +153,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -330,7 +329,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 2cacea99a..748ce4ea5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -83,7 +83,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil { + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(v).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -617,10 +620,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Header[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Header[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 70e008d36..9c40931b5 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -429,7 +429,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: data.ToVectorisedView(), + TransportHeader: buffer.View(icmpv4), + }) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -455,7 +459,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, dataVV, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: dataVV, + TransportHeader: buffer.View(icmpv6), + }) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 230a1537a..5aafe2615 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -338,13 +338,18 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { + if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index f24b51b91..023045ec1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -419,8 +419,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // TODO(b/143300739): Use the userMSS of the listening socket // for accepted sockets. - switch s.flags { - case header.TCPFlagSyn: + switch { + case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) if incSynRcvdCount() { // Only handle the syn if the following conditions hold @@ -464,7 +464,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } - case header.TCPFlagAck: + case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -478,6 +478,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } if !synCookiesInUse() { + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // // Send a reset as this is an ACK for which there is no // half open connections and we are not using cookies // yet. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a114c06c1..364067731 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -723,7 +723,10 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise if ttl == 0 { ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + Header: d.Hdr, + Data: data, + }); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } @@ -862,6 +865,33 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateCloseLocked ensures that the endpoint is +// cleaned up from the transport demuxer, "before" moving to +// StateClose. This will ensure that no packet will be +// delivered to this endpoint from the demuxer when the endpoint +// is transitioned to StateClose. +func (e *endpoint) transitionToStateCloseLocked() { + if e.state == StateClose { + return + } + e.cleanupLocked() + e.state = StateClose +} + +// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed +// segment to any other endpoint other than the current one. This is called +// only when the endpoint is in StateClose and we want to deliver the segment +// to any other listening endpoint. We reply with RST if we cannot find one. +func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil { + replyWithReset(s) + s.decRef() + return + } + ep.(*endpoint).enqueueSegment(s) +} + func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states @@ -891,12 +921,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // general "connection reset" signal. Enter the CLOSED state, // delete the TCB, and return. case StateCloseWait: - e.state = StateClose + e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted - // We need to set this explicitly here because otherwise - // the port registrations will not be released till the - // endpoint is actively closed by the application. - e.workerCleanup = true e.mu.Unlock() return false, nil default: @@ -912,6 +938,20 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + return nil + } + s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false @@ -1157,7 +1197,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // to the TCP_FIN_WAIT2 timeout was hit. Just // mark the socket as closed. e.mu.Lock() - e.state = StateClose + e.transitionToStateCloseLocked() e.mu.Unlock() return nil }, @@ -1229,7 +1269,9 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } } - if e.state != StateError { + if e.state != StateClose && e.state != StateError { + // Only block the worker if the endpoint + // is not in closed state or error state. close(e.drainDone) <-e.undrain } @@ -1316,12 +1358,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if e.state != StateError { e.stack.Stats().TCP.EstablishedResets.Increment() e.stack.Stats().TCP.CurrentEstablished.Decrement() - e.state = StateClose + e.transitionToStateCloseLocked() } // Lock released below. epilogue() + // epilogue removes the endpoint from the transport-demuxer and + // unlocks e.mu. Now that no new segments can get enqueued to this + // endpoint, try to re-match the segment to a different endpoint + // as the current endpoint is closed. + for !e.segmentQueue.empty() { + s := e.segmentQueue.dequeue() + e.tryDeliverSegmentFromClosedEndpoint(s) + } + // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue if reuseTW != nil { diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 068b90fb6..857dc445f 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -218,7 +218,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateClosing: r.ep.state = StateTimeWait case StateLastAck: - r.ep.state = StateClose + r.ep.transitionToStateCloseLocked() } r.ep.mu.Unlock() } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index b443fe9dc..64f765c70 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -454,6 +454,112 @@ func TestConnectResetAfterClose(t *testing.T) { } } +// TestClosingWithEnqueuedSegments tests handling of +// still enqueued segments when the endpoint transitions +// to StateClose. The in-flight segments would be re-enqueued +// to a any listening endpoint. +func TestClosingWithEnqueuedSegments(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Send a FIN for ESTABLISHED --> CLOSED-WAIT + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Get the ACK for the FIN we sent. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Close the application endpoint for CLOSE_WAIT --> LAST_ACK + ep.Close() + + // Get the FIN + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Pause the endpoint`s protocolMainLoop. + ep.(interface{ StopWork() }).StopWork() + + // Enqueue last ACK followed by an ACK matching the endpoint + // + // Send Last ACK for LAST_ACK --> CLOSED + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Send a packet with ACK set, this would generate RST when + // not using SYN cookies as in this test. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 792, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Unpause endpoint`s protocolMainLoop. + ep.(interface{ ResumeWork() }).ResumeWork() + + // Wait for the protocolMainLoop to resume and update state. + time.Sleep(1 * time.Millisecond) + + // Expect the endpoint to be closed. + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Check if the endpoint was moved to CLOSED and netstack a reset in + // response to the ACK packet that we sent after last-ACK. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + ), + ) +} + func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -686,6 +792,96 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.SeqNum(200))) } +func TestSendRstOnListenerRxAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true /* v6Only */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(200))) +} + +// TestListenShutdown tests for the listening endpoint not processing +// any receive when it is on read shutdown. +func TestListenShutdown(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatal("Shutdown failed:", err) + } + + // Wait for the endpoint state to be propagated. + time.Sleep(10 * time.Millisecond) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: 100, + AckNum: 200, + }) + + c.CheckNoPacket("Packet received when listening socket was shutdown") +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 0a733fa94..04fdaaed1 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -236,9 +236,9 @@ func (c *Context) GetPacket() []byte { if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -264,9 +264,9 @@ func (c *Context) GetPacketNonBlocking() []byte { if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b @@ -488,9 +488,9 @@ func (c *Context) GetV6Packet() []byte { if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) + copy(b, p.Pkt.Header.View()) + copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index dda7af910..2d97d1398 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -817,7 +817,10 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + Header: hdr, + Data: data, + }); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 43f11b700..259c3072a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -135,7 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -169,7 +172,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 30ee9801b..7051a7a9c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -356,9 +356,9 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw if p.Proto != flow.netProto() { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) h := flow.header4Tuple(outgoing) checkers := append( @@ -1453,8 +1453,8 @@ func TestV4UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1527,8 +1527,8 @@ func TestV6UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/images/proctor/proctor.go index e6178e82b..b54abe434 100644 --- a/test/runtimes/images/proctor/proctor.go +++ b/test/runtimes/images/proctor/proctor.go @@ -39,10 +39,10 @@ type TestRunner interface { } var ( - runtime = flag.String("runtime", "", "name of runtime") - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") + runtime = flag.String("runtime", "", "name of runtime") + list = flag.Bool("list", false, "list all available tests") + testName = flag.String("test", "", "run a single test from the list of available tests") + pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") ) func main() { @@ -74,14 +74,23 @@ func main() { return } - // Run a single test. - if *test == "" { - log.Fatalf("test flag must be provided") + var tests []string + if *testName == "" { + // Run every test. + tests, err = tr.ListTests() + if err != nil { + log.Fatalf("failed to get all tests: %v", err) + } + } else { + // Run a single test. + tests = []string{*testName} } - cmd := tr.TestCmd(*test) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - log.Fatalf("FAIL: %v", err) + for _, test := range tests { + cmd := tr.TestCmd(test) + cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr + if err := cmd.Run(); err != nil { + log.Fatalf("FAIL: %v", err) + } } } diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 2b1df52ce..267ae19f6 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -73,6 +73,28 @@ class OpenTest : public FileTest { const std::string test_data_ = "hello world\n"; }; +TEST_F(OpenTest, OTrunc) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_TRUNC, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, OTruncAndReadOnlyDir) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, OTruncAndReadOnlyFile) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); + const FileDescriptor existing = + ASSERT_NO_ERRNO_AND_VALUE(Open(dirpath.c_str(), O_RDWR | O_CREAT, 0666)); + const FileDescriptor otrunc = ASSERT_NO_ERRNO_AND_VALUE( + Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666)); +} + TEST_F(OpenTest, ReadOnly) { char buf; const FileDescriptor ro_file = diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index e5a85ef9d..431733dbe 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -88,6 +88,30 @@ TEST(CreateTest, CreateExclusively) { SyscallFailsWithErrno(EEXIST)); } +TEST(CreateTeast, CreatWithOTrunc) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); + int dirfd; + ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), + SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), + SyscallSucceeds()); + ASSERT_THAT(close(dirfd), SyscallSucceeds()); +} + TEST(CreateTest, CreateFailsOnUnpermittedDir) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // always override directory permissions. diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index e4c030bbb..512de5ee0 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -183,7 +183,8 @@ PosixError WithSubprocess(SubprocessCallback const& running, siginfo_t info; // Wait until the child process has exited (WEXITED flag) but don't // reap the child (WNOWAIT flag). - waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED); + EXPECT_THAT(waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED), + SyscallSucceeds()); if (zombied) { // Arg of "Z" refers to a Zombied Process. diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index 99a0df235..dafe64d20 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -70,6 +70,8 @@ constexpr absl::Duration kTimeout = absl::Seconds(20); // The maximum line size in bytes returned per read from a pty file. constexpr int kMaxLineSize = 4096; +constexpr char kMasterPath[] = "/dev/ptmx"; + // glibc defines its own, different, version of struct termios. We care about // what the kernel does, not glibc. #define KERNEL_NCCS 19 @@ -376,9 +378,25 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, return PosixError(ETIMEDOUT, "Poll timed out"); } +TEST(PtyTrunc, Truncate) { + // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to + // (f)truncate should. + FileDescriptor master = + ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC)); + int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master)); + std::string spath = absl::StrCat("/dev/pts/", n); + FileDescriptor slave = + ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); + + EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL)); +} + TEST(BasicPtyTest, StatUnopenedMaster) { struct stat s; - ASSERT_THAT(stat("/dev/ptmx", &s), SyscallSucceeds()); + ASSERT_THAT(stat(kMasterPath, &s), SyscallSucceeds()); EXPECT_EQ(s.st_rdev, makedev(TTYAUX_MAJOR, kPtmxMinor)); EXPECT_EQ(s.st_size, 0); diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index b02872308..b6754111f 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -354,6 +354,38 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) { EXPECT_EQ(get, expect); } +TEST_P(IPUnboundSocketTest, NullTOS) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + TOSOption t = GetTOSOption(GetParam().domain); + int set_sz = sizeof(int); + if (GetParam().domain == AF_INET) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + SyscallFailsWithErrno(EFAULT)); + } else { // AF_INET6 + // The AF_INET6 behavior is not yet compatible. gVisor will try to read + // optval from user memory at syscall handler, it needs substantial + // refactoring to implement this behavior just for IPv6. + if (IsRunningOnGvisor()) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + SyscallFailsWithErrno(EFAULT)); + } else { + // Linux's IPv6 stack treats nullptr optval as input of 0, so the call + // succeeds. (net/ipv6/ipv6_sockglue.c, do_ipv6_setsockopt()) + // + // Linux's implementation would need fixing as passing a nullptr as optval + // and non-zero optlen may not be valid. + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + SyscallSucceedsWithValue(0)); + } + } + socklen_t get_sz = sizeof(int); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, nullptr, &get_sz), + SyscallFailsWithErrno(EFAULT)); + int get = -1; + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, nullptr), + SyscallFailsWithErrno(EFAULT)); +} + INSTANTIATE_TEST_SUITE_P( IPUnboundSockets, IPUnboundSocketTest, ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>( diff --git a/test/util/BUILD b/test/util/BUILD index 5d2a9cc2c..4526bb3f1 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -232,7 +232,13 @@ cc_library( cc_library( name = "test_util", testonly = 1, - srcs = ["test_util.cc"], + srcs = [ + "test_util.cc", + ] + select_for_linux( + [ + "test_util_impl.cc", + ], + ), hdrs = ["test_util.h"], deps = [ ":fs_util", diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc index c0fd9a095..c01f916aa 100644 --- a/test/util/pty_util.cc +++ b/test/util/pty_util.cc @@ -24,6 +24,14 @@ namespace gvisor { namespace testing { PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { + PosixErrorOr<int> n = SlaveID(master); + if (!n.ok()) { + return PosixErrorOr<FileDescriptor>(n.error()); + } + return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK); +} + +PosixErrorOr<int> SlaveID(const FileDescriptor& master) { // Get pty index. int n; int ret = ioctl(master.get(), TIOCGPTN, &n); @@ -38,7 +46,7 @@ PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { return PosixError(errno, "ioctl(TIOSPTLCK) failed"); } - return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK); + return n; } } // namespace testing diff --git a/test/util/pty_util.h b/test/util/pty_util.h index 367b14f15..0722da379 100644 --- a/test/util/pty_util.h +++ b/test/util/pty_util.h @@ -24,6 +24,9 @@ namespace testing { // Opens the slave end of the passed master as R/W and nonblocking. PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); +// Get the number of the slave end of the master. +PosixErrorOr<int> SlaveID(const FileDescriptor& master); + } // namespace testing } // namespace gvisor diff --git a/test/util/test_util.cc b/test/util/test_util.cc index ba0dcf7d0..9cb050735 100644 --- a/test/util/test_util.cc +++ b/test/util/test_util.cc @@ -116,9 +116,6 @@ PosixErrorOr<KernelVersion> GetKernelVersion() { return ParseKernelVersion(buf.release); } -void SetupGvisorDeathTest() { -} - std::string CPUSetToString(const cpu_set_t& set, size_t cpus) { std::string str = "cpuset["; for (unsigned int n = 0; n < cpus; n++) { @@ -224,15 +221,5 @@ bool Equivalent(uint64_t current, uint64_t target, double tolerance) { return abs_diff <= static_cast<uint64_t>(tolerance * target); } -void TestInit(int* argc, char*** argv) { - ::testing::InitGoogleTest(argc, *argv); - ::absl::ParseCommandLine(*argc, *argv); - - // Always mask SIGPIPE as it's common and tests aren't expected to handle it. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); -} - } // namespace testing } // namespace gvisor diff --git a/test/util/test_util.h b/test/util/test_util.h index b9d2dc2ba..dc30575b8 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -221,7 +221,9 @@ enum class Platform { bool IsRunningOnGvisor(); Platform GvisorPlatform(); +#ifdef __linux__ void SetupGvisorDeathTest(); +#endif struct KernelVersion { int major; diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc new file mode 100644 index 000000000..ba7c0a85b --- /dev/null +++ b/test/util/test_util_impl.cc @@ -0,0 +1,38 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> + +#include "gtest/gtest.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +void SetupGvisorDeathTest() {} + +void TestInit(int* argc, char*** argv) { + ::testing::InitGoogleTest(argc, *argv); + ::absl::ParseCommandLine(*argc, *argv); + + // Always mask SIGPIPE as it's common and tests aren't expected to handle it. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); +} + +} // namespace testing +} // namespace gvisor diff --git a/tools/tag_release.sh b/tools/tag_release.sh index 9d5a60583..f33b902d6 100755 --- a/tools/tag_release.sh +++ b/tools/tag_release.sh @@ -64,5 +64,6 @@ fi # Tag the given commit (annotated, to record the committer). declare -r tag="release-${release}" -(git tag -a "${tag}" "${commit}" && git push origin tag "${tag}") || \ +(git tag -m "Release ${release}" -a "${tag}" "${commit}" && \ + git push origin tag "${tag}") || \ (git tag -d "${tag}" && false) |