diff options
Diffstat (limited to 'pkg/tcpip/link')
23 files changed, 306 insertions, 300 deletions
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a068d93a4..cd76272de 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -229,7 +229,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { p := PacketInfo{ Pkt: pkt, Proto: protocol, @@ -243,7 +243,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index 2f2d9d4ac..d873766a6 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -61,13 +61,13 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) return e.Endpoint.WritePacket(r, gso, proto, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 10072eac1..ae1394ebf 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -35,7 +35,6 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f86c383d8..0164d851b 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -57,7 +57,7 @@ import ( // linkDispatcher reads packets from the link FD and dispatches them to the // NetworkDispatcher. type linkDispatcher interface { - dispatch() (bool, *tcpip.Error) + dispatch() (bool, tcpip.Error) } // PacketDispatchMode are the various supported methods of receiving and @@ -118,7 +118,7 @@ type endpoint struct { // closed is a function to be called when the FD's peer (if any) closes // its end of the communication pipe. - closed func(*tcpip.Error) + closed func(tcpip.Error) inboundDispatchers []linkDispatcher dispatcher stack.NetworkDispatcher @@ -149,7 +149,7 @@ type Options struct { // ClosedFunc is a function to be called when an endpoint's peer (if // any) closes its end of the communication pipe. - ClosedFunc func(*tcpip.Error) + ClosedFunc func(tcpip.Error) // Address is the link address for this endpoint. Only used if // EthernetHeader is true. @@ -411,7 +411,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if e.hdrSize > 0 { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } @@ -451,7 +451,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } -func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { +func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) { // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { @@ -518,7 +518,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { // Preallocate to avoid repeated reallocation as we append to batch. // batchSz is 47 because when SWGSO is in use then a single 65KB TCP // segment can get split into 46 segments of 1420 bytes and a single 216 @@ -562,13 +562,13 @@ func viewsEqual(vs1, vs2 []buffer.View) bool { } // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. -func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { +func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) } // dispatchLoop reads packets from the file descriptor in a loop and dispatches // them to the network stack. -func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error { +func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { for { cont, err := inboundDispatcher.dispatch() if err != nil || !cont { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 90da22d34..e82371798 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -96,7 +95,7 @@ func newContext(t *testing.T, opt *Options) *context { } done := make(chan struct{}, 2) - opt.ClosedFunc = func(*tcpip.Error) { + opt.ClosedFunc = func(tcpip.Error) { done <- struct{}{} } @@ -465,67 +464,85 @@ var capLengthTestCases = []struct { config: []int{1, 2, 3}, n: 3, wantUsed: 2, - wantLengths: []int{1, 2, 3}, + wantLengths: []int{1, 2}, }, } -func TestReadVDispatcherCapLength(t *testing.T) { +func TestIovecBuffer(t *testing.T) { for _, c := range capLengthTestCases { - // fd does not matter for this test. - d := readVDispatcher{fd: -1, e: &endpoint{}} - d.views = make([]buffer.View, len(c.config)) - d.iovecs = make([]syscall.Iovec, len(c.config)) - d.allocateViews(c.config) - - used := d.capViews(c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views)) - for i, v := range d.views { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - } -} + t.Run(c.comment, func(t *testing.T) { + b := newIovecBuffer(c.config, false /* skipsVnetHdr */) -func TestRecvMMsgDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - d := recvMMsgDispatcher{ - fd: -1, // fd does not matter for this test. - e: &endpoint{}, - views: make([][]buffer.View, 1), - iovecs: make([][]syscall.Iovec, 1), - msgHdrs: make([]rawfile.MMsgHdr, 1), - } + // Test initial allocation. + iovecs := b.nextIovecs() + if got, want := len(iovecs), len(c.config); got != want { + t.Fatalf("len(iovecs) = %d, want %d", got, want) + } - for i := range d.views { - d.views[i] = make([]buffer.View, len(c.config)) - } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, len(c.config)) - } - for k, msgHdr := range d.msgHdrs { - msgHdr.Msg.Iov = &d.iovecs[k][0] - msgHdr.Msg.Iovlen = uint64(len(c.config)) - } + // Make a copy as iovecs points to internal slice. We will need this state + // later. + oldIovecs := append([]syscall.Iovec(nil), iovecs...) - d.allocateViews(c.config) + // Test the views that get pulled. + vv := b.pullViews(c.n) + var lengths []int + for _, v := range vv.Views() { + lengths = append(lengths, len(v)) + } + if !reflect.DeepEqual(lengths, c.wantLengths) { + t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths) + } - used := d.capViews(0, c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views[0])) - for i, v := range d.views[0] { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } + // Test that new views get reallocated. + for i, newIov := range b.nextIovecs() { + if i < c.wantUsed { + if newIov.Base == oldIovecs[i].Base { + t.Errorf("b.views[%d] should have been reallocated", i) + } + } else { + if newIov.Base != oldIovecs[i].Base { + t.Errorf("b.views[%d] should not have been reallocated", i) + } + } + } + }) + } +} +func TestIovecBufferSkipVnetHdr(t *testing.T) { + for _, test := range []struct { + desc string + readN int + wantLen int + }{ + { + desc: "nothing read", + readN: 0, + wantLen: 0, + }, + { + desc: "smaller than vnet header", + readN: virtioNetHdrSize - 1, + wantLen: 0, + }, + { + desc: "header skipped", + readN: virtioNetHdrSize + 100, + wantLen: 100, + }, + } { + t.Run(test.desc, func(t *testing.T) { + b := newIovecBuffer([]int{10, 20, 50, 50}, true) + // Pretend a read happend. + b.nextIovecs() + vv := b.pullViews(test.readN) + if got, want := vv.Size(), test.wantLen; got != want { + t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want) + } + if got, want := len(vv.ToOwnedView()), test.wantLen; got != want { + t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want) + } + }) } } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index c475dda20..a2b63fe6b 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -129,7 +129,7 @@ type packetMMapDispatcher struct { ringOffset int } -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) for hdr.tpStatus()&tpStatusUser == 0 { event := rawfile.PollEvent{ @@ -163,7 +163,7 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { // dispatch reads packets from an mmaped ring buffer and dispatches them to the // network stack. -func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { +func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { pkt, err := d.readMMappedPacket() if err != nil { return false, err diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 8c3ca86d6..ecae1ad2d 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -29,92 +29,124 @@ import ( // BufConfig defines the shape of the vectorised view used to read packets from the NIC. var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} -// readVDispatcher uses readv() system call to read inbound packets and -// dispatches them. -type readVDispatcher struct { - // fd is the file descriptor used to send and receive packets. - fd int - - // e is the endpoint this dispatcher is attached to. - e *endpoint - +type iovecBuffer struct { // views are the actual buffers that hold the packet contents. views []buffer.View // iovecs are initialized with base pointers/len of the corresponding - // entries in the views defined above, except when GSO is enabled then - // the first iovec points to a buffer for the vnet header which is - // stripped before the views are passed up the stack for further + // entries in the views defined above, except when GSO is enabled + // (skipsVnetHdr) then the first iovec points to a buffer for the vnet header + // which is stripped before the views are passed up the stack for further // processing. iovecs []syscall.Iovec + + // sizes is an array of buffer sizes for the underlying views. sizes is + // immutable. + sizes []int + + // skipsVnetHdr is true if virtioNetHdr is to skipped. + skipsVnetHdr bool } -func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { - d := &readVDispatcher{fd: fd, e: e} - d.views = make([]buffer.View, len(BufConfig)) - iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - iovLen++ +func newIovecBuffer(sizes []int, skipsVnetHdr bool) *iovecBuffer { + b := &iovecBuffer{ + views: make([]buffer.View, len(sizes)), + sizes: sizes, + skipsVnetHdr: skipsVnetHdr, } - d.iovecs = make([]syscall.Iovec, iovLen) - return d, nil + niov := len(b.views) + if b.skipsVnetHdr { + niov++ + } + b.iovecs = make([]syscall.Iovec, niov) + return b } -func (d *readVDispatcher) allocateViews(bufConfig []int) { - var vnetHdr [virtioNetHdrSize]byte +func (b *iovecBuffer) nextIovecs() []syscall.Iovec { vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if b.skipsVnetHdr { + var vnetHdr [virtioNetHdrSize]byte // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. - d.iovecs[0] = syscall.Iovec{ + b.iovecs[0] = syscall.Iovec{ Base: &vnetHdr[0], Len: uint64(virtioNetHdrSize), } vnetHdrOff++ } - for i := 0; i < len(bufConfig); i++ { - if d.views[i] != nil { + for i := range b.views { + if b.views[i] != nil { break } - b := buffer.NewView(bufConfig[i]) - d.views[i] = b - d.iovecs[i+vnetHdrOff] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), + v := buffer.NewView(b.sizes[i]) + b.views[i] = v + b.iovecs[i+vnetHdrOff] = syscall.Iovec{ + Base: &v[0], + Len: uint64(len(v)), } } + return b.iovecs } -func (d *readVDispatcher) capViews(n int, buffers []int) int { +func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView { + var views []buffer.View c := 0 - for i, s := range buffers { - c += s + if b.skipsVnetHdr { + c += virtioNetHdrSize if c >= n { - d.views[i].CapLength(s - (c - n)) - return i + 1 + // Nothing in the packet. + return buffer.NewVectorisedView(0, nil) + } + } + for i, v := range b.views { + c += len(v) + if c >= n { + b.views[i].CapLength(len(v) - (c - n)) + views = append([]buffer.View(nil), b.views[:i+1]...) + break } } - return len(buffers) + // Remove the first len(views) used views from the state. + for i := range views { + b.views[i] = nil + } + if b.skipsVnetHdr { + // Exclude the size of the vnet header. + n -= virtioNetHdrSize + } + return buffer.NewVectorisedView(n, views) } -// dispatch reads one packet from the file descriptor and dispatches it. -func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { - d.allocateViews(BufConfig) +// readVDispatcher uses readv() system call to read inbound packets and +// dispatches them. +type readVDispatcher struct { + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // buf is the iovec buffer that contains the packet contents. + buf *iovecBuffer +} + +func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + d := &readVDispatcher{fd: fd, e: e} + skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) + return d, nil +} - n, err := rawfile.BlockingReadv(d.fd, d.iovecs) +// dispatch reads one packet from the file descriptor and dispatches it. +func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { + n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs()) if n == 0 || err != nil { return false, err } - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // Skip virtioNetHdr which is added before each packet, it - // isn't used and it isn't in a view. - n -= virtioNetHdrSize - } - used := d.capViews(n, BufConfig) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + Data: d.buf.pullViews(n), }) var ( @@ -133,7 +165,12 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(d.views[0]) { + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data.PullUp(1) + if !ok { + return true, nil + } + switch header.IPVersion(h) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: @@ -145,11 +182,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) - // Prepare e.views for another packet: release used views. - for i := 0; i < used; i++ { - d.views[i] = nil - } - return true, nil } @@ -162,15 +194,8 @@ type recvMMsgDispatcher struct { // e is the endpoint this dispatcher is attached to. e *endpoint - // views is an array of array of buffers that contain packet contents. - views [][]buffer.View - - // iovecs is an array of array of iovec records where each iovec base - // pointer and length are initialzed to the corresponding view above, - // except when GSO is enabled then the first iovec in each array of - // iovecs points to a buffer for the vnet header which is stripped - // before the views are passed up the stack for further processing. - iovecs [][]syscall.Iovec + // bufs is an array of iovec buffers that contain packet contents. + bufs []*iovecBuffer // msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to // reference an array of iovecs in the iovecs field defined above. This @@ -187,74 +212,32 @@ const ( func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &recvMMsgDispatcher{ - fd: fd, - e: e, - } - d.views = make([][]buffer.View, MaxMsgsPerRecv) - for i := range d.views { - d.views[i] = make([]buffer.View, len(BufConfig)) - } - d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv) - iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // virtioNetHdr is prepended before each packet. - iovLen++ + fd: fd, + e: e, + bufs: make([]*iovecBuffer, MaxMsgsPerRecv), + msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, iovLen) - } - d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv) - for i := range d.msgHdrs { - d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0] - d.msgHdrs[i].Msg.Iovlen = uint64(iovLen) + skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + for i := range d.bufs { + d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) } return d, nil } -func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int { - c := 0 - for i, s := range buffers { - c += s - if c >= n { - d.views[k][i].CapLength(s - (c - n)) - return i + 1 - } - } - return len(buffers) -} - -func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) { - for k := 0; k < len(d.views); k++ { - var vnetHdr [virtioNetHdrSize]byte - vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // The kernel adds virtioNetHdr before each packet, but - // we don't use it, so so we allocate a buffer for it, - // add it in iovecs but don't add it in a view. - d.iovecs[k][0] = syscall.Iovec{ - Base: &vnetHdr[0], - Len: uint64(virtioNetHdrSize), - } - vnetHdrOff++ - } - for i := 0; i < len(bufConfig); i++ { - if d.views[k][i] != nil { - break - } - b := buffer.NewView(bufConfig[i]) - d.views[k][i] = b - d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), - } - } - } -} - // recvMMsgDispatch reads more than one packet at a time from the file // descriptor and dispatches it. -func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { - d.allocateViews(BufConfig) +func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { + // Fill message headers. + for k := range d.msgHdrs { + if d.msgHdrs[k].Msg.Iovlen > 0 { + break + } + iovecs := d.bufs[k].nextIovecs() + iovLen := len(iovecs) + d.msgHdrs[k].Len = 0 + d.msgHdrs[k].Msg.Iov = &iovecs[0] + d.msgHdrs[k].Msg.Iovlen = uint64(iovLen) + } nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs) if err != nil { @@ -263,15 +246,14 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { // Process each of received packets. for k := 0; k < nMsgs; k++ { n := int(d.msgHdrs[k].Len) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - n -= virtioNetHdrSize - } - used := d.capViews(k, int(n), BufConfig) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + Data: d.bufs[k].pullViews(n), }) + // Mark that this iovec has been processed. + d.msgHdrs[k].Msg.Iovlen = 0 + var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress @@ -288,26 +270,24 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(d.views[k][0]) { + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data.PullUp(1) + if !ok { + // Skip this packet. + continue + } + switch header.IPVersion(h) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: p = header.IPv6ProtocolNumber default: - return true, nil + // Skip this packet. + continue } } d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) - - // Prepare e.views for another packet: release used views. - for i := 0; i < used; i++ { - d.views[k][i] = nil - } - } - - for k := 0; k < nMsgs; k++ { - d.msgHdrs[k].Len = 0 } return true, nil diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ac6a6be87..691467870 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // Construct data as the unparsed portion for the loopback packet. data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 316f508e6..668f72eee 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,10 +87,10 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } return endpoint.WritePackets(r, gso, pkts, protocol) } @@ -98,19 +98,19 @@ func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkt // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } // InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error { endpoint, ok := m.routes[dest] if !ok { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } return endpoint.InjectOutbound(dest, packet) } diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 814a54f23..97ad9fdd5 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -113,12 +113,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return e.child.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return e.child.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index c95cdd681..6cbe18a56 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -35,13 +35,13 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 36aa9055c..bbe84f220 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -75,7 +75,7 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if e.linked.IsAttached() { var pkts stack.PacketBufferList pkts.PushBack(pkt) @@ -86,7 +86,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { if e.linked.IsAttached() { e.deliverPackets(r, proto, pkts) } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 03efba606..128ef6e87 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -150,7 +150,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. pkt.EgressRoute = r @@ -158,7 +158,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] if !d.q.enqueue(pkt) { - return tcpip.ErrNoBufferSpace + return &tcpip.ErrNoBufferSpace{} } d.newPacketWaker.Assert() return nil @@ -171,7 +171,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { enqueued := 0 for pkt := pkts.Front(); pkt != nil; { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] @@ -180,7 +180,7 @@ func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.Pa if enqueued > 0 { d.newPacketWaker.Assert() } - return enqueued, tcpip.ErrNoBufferSpace + return enqueued, &tcpip.ErrNoBufferSpace{} } pkt = nxt enqueued++ diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 6c410c5a6..e1047da50 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -27,5 +27,6 @@ go_test( library = "rawfile", deps = [ "//pkg/tcpip", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index 604868fd8..406b97709 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -17,7 +17,6 @@ package rawfile import ( - "fmt" "syscall" "gvisor.dev/gvisor/pkg/tcpip" @@ -25,48 +24,54 @@ import ( const maxErrno = 134 -var translations [maxErrno]*tcpip.Error - // TranslateErrno translate an errno from the syscall package into a -// *tcpip.Error. +// tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). -func TranslateErrno(e syscall.Errno) *tcpip.Error { - if e > 0 && e < syscall.Errno(len(translations)) { - if err := translations[e]; err != nil { - return err - } - } - return tcpip.ErrInvalidEndpointState -} - -func addTranslation(host syscall.Errno, trans *tcpip.Error) { - if translations[host] != nil { - panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host)) +// *tcpip.ErrInvalidEndpointState (EINVAL). +func TranslateErrno(e syscall.Errno) tcpip.Error { + switch e { + case syscall.EEXIST: + return &tcpip.ErrDuplicateAddress{} + case syscall.ENETUNREACH: + return &tcpip.ErrNoRoute{} + case syscall.EINVAL: + return &tcpip.ErrInvalidEndpointState{} + case syscall.EALREADY: + return &tcpip.ErrAlreadyConnecting{} + case syscall.EISCONN: + return &tcpip.ErrAlreadyConnected{} + case syscall.EADDRINUSE: + return &tcpip.ErrPortInUse{} + case syscall.EADDRNOTAVAIL: + return &tcpip.ErrBadLocalAddress{} + case syscall.EPIPE: + return &tcpip.ErrClosedForSend{} + case syscall.EWOULDBLOCK: + return &tcpip.ErrWouldBlock{} + case syscall.ECONNREFUSED: + return &tcpip.ErrConnectionRefused{} + case syscall.ETIMEDOUT: + return &tcpip.ErrTimeout{} + case syscall.EINPROGRESS: + return &tcpip.ErrConnectStarted{} + case syscall.EDESTADDRREQ: + return &tcpip.ErrDestinationRequired{} + case syscall.ENOTSUP: + return &tcpip.ErrNotSupported{} + case syscall.ENOTTY: + return &tcpip.ErrQueueSizeNotSupported{} + case syscall.ENOTCONN: + return &tcpip.ErrNotConnected{} + case syscall.ECONNRESET: + return &tcpip.ErrConnectionReset{} + case syscall.ECONNABORTED: + return &tcpip.ErrConnectionAborted{} + case syscall.EMSGSIZE: + return &tcpip.ErrMessageTooLong{} + case syscall.ENOBUFS: + return &tcpip.ErrNoBufferSpace{} + default: + return &tcpip.ErrInvalidEndpointState{} } - translations[host] = trans -} - -func init() { - addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress) - addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute) - addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState) - addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting) - addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected) - addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse) - addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress) - addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend) - addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock) - addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused) - addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout) - addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted) - addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired) - addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported) - addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported) - addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected) - addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset) - addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted) - addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong) - addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace) } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go index e4cdc66bd..61aea1744 100644 --- a/pkg/tcpip/link/rawfile/errors_test.go +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -20,34 +20,35 @@ import ( "syscall" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) func TestTranslateErrno(t *testing.T) { for _, test := range []struct { errno syscall.Errno - translated *tcpip.Error + translated tcpip.Error }{ { errno: syscall.Errno(0), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.Errno(maxErrno), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.Errno(514), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.EEXIST, - translated: tcpip.ErrDuplicateAddress, + translated: &tcpip.ErrDuplicateAddress{}, }, } { got := TranslateErrno(test.errno) - if got != test.translated { - t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + if diff := cmp.Diff(test.translated, got); diff != "" { + t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff) } } } diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index f4c32c2da..06f3ee21e 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -52,7 +52,7 @@ func GetMTU(name string) (uint32, error) { // NonBlockingWrite writes the given buffer to a file descriptor. It fails if // partial data is written. -func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { +func NonBlockingWrite(fd int, buf []byte) tcpip.Error { var ptr unsafe.Pointer if len(buf) > 0 { ptr = unsafe.Pointer(&buf[0]) @@ -68,7 +68,7 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { // NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. // It fails if partial data is written. -func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) tcpip.Error { iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { @@ -78,7 +78,7 @@ func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { } // NonBlockingSendMMsg sends multiple messages on a socket. -func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { +func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) if e != 0 { return 0, TranslateErrno(e) @@ -97,7 +97,7 @@ type PollEvent struct { // BlockingRead reads from a file descriptor that is set up as non-blocking. If // no data is available, it will block in a poll() syscall until the file // descriptor becomes readable. -func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { +func BlockingRead(fd int, b []byte) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) if e == 0 { @@ -119,7 +119,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { // BlockingReadv reads from a file descriptor that is set up as non-blocking and // stores the data in a list of iovecs buffers. If no data is available, it will // block in a poll() syscall until the file descriptor becomes readable. -func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { +func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) if e == 0 { @@ -149,7 +149,7 @@ type MMsgHdr struct { // and stores the received messages in a slice of MMsgHdr structures. If no data // is available, it will block in a poll() syscall until the file descriptor // becomes readable. -func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { +func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) if e == 0 { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6c937c858..2599bc406 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) views := pkt.Views() @@ -213,14 +213,14 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N e.mu.Unlock() if !ok { - return tcpip.ErrWouldBlock + return &tcpip.ErrWouldBlock{} } return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 23242b9e0..d480ad656 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -425,8 +425,9 @@ func TestFillTxQueue(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -493,8 +494,9 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -538,8 +540,8 @@ func TestFillTxMemory(t *testing.T) { Data: buf.ToVectorisedView(), }) err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) - if want := tcpip.ErrWouldBlock; err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -579,8 +581,9 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buffer.NewView(bufferSize).ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 5859851d8..bd2b8d4bf 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -187,7 +187,7 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.dumpPacket(directionSend, gso, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } @@ -195,7 +195,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.dumpPacket(directionSend, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index bfac358f4..3829ca9c9 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -149,10 +149,10 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ Name: endpoint.name, }) - switch err { + switch err.(type) { case nil: return endpoint, nil - case tcpip.ErrDuplicateNICID: + case *tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 30f1ad540..20259b285 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -108,7 +108,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if !e.writeGate.Enter() { return nil } @@ -121,7 +121,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { if !e.writeGate.Enter() { return pkts.Len(), nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index b139de7dd..e368a9eaa 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { e.writeCount += pkts.Len() return pkts.Len(), nil } |