diff options
44 files changed, 360 insertions, 435 deletions
diff --git a/pkg/dhcp/dhcp_test.go b/pkg/dhcp/dhcp_test.go index 67814683a..a187c5c2a 100644 --- a/pkg/dhcp/dhcp_test.go +++ b/pkg/dhcp/dhcp_test.go @@ -42,11 +42,7 @@ func createStack(t *testing.T) *stack.Stack { go func() { for pkt := range linkEP.C { - v := make(buffer.View, len(pkt.Header)+len(pkt.Payload)) - copy(v, pkt.Header) - copy(v[len(pkt.Header):], pkt.Payload) - vv := v.ToVectorisedView([1]buffer.View{}) - linkEP.Inject(pkt.Proto, &vv) + linkEP.Inject(pkt.Proto, buffer.NewVectorisedView(len(pkt.Header)+len(pkt.Payload), []buffer.View{pkt.Header, pkt.Payload})) } }() diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 19491fb2c..490b9c648 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -350,8 +350,7 @@ func (c *Conn) Write(b []byte) (int, error) { default: } - v := buffer.NewView(len(b)) - copy(v, b) + v := buffer.NewViewFromBytes(b) // We must handle two soft failure conditions simultaneously: // 1. Write may write nothing and return tcpip.ErrWouldBlock. diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 4a921ddcb..cea4e3657 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -45,11 +45,9 @@ func (v *View) CapLength(length int) { *v = (*v)[:length:length] } -// ToVectorisedView transforms a View in a VectorisedView from an -// already-allocated slice of View. -func (v *View) ToVectorisedView(views [1]View) VectorisedView { - views[0] = *v - return NewVectorisedView(len(*v), views[:]) +// ToVectorisedView returns a VectorisedView containing the receiver. +func (v View) ToVectorisedView() VectorisedView { + return NewVectorisedView(len(v), []View{v}) } // VectorisedView is a vectorised version of View using non contigous memory. @@ -107,21 +105,12 @@ func (vv *VectorisedView) CapLength(length int) { // Clone returns a clone of this VectorisedView. // If the buffer argument is large enough to contain all the Views of this VectorisedView, // the method will avoid allocations and use the buffer to store the Views of the clone. -func (vv *VectorisedView) Clone(buffer []View) VectorisedView { - var views []View - if len(buffer) >= len(vv.views) { - views = buffer[:len(vv.views)] - } else { - views = make([]View, len(vv.views)) - } - for i, v := range vv.views { - views[i] = v - } - return VectorisedView{views: views, size: vv.size} +func (vv VectorisedView) Clone(buffer []View) VectorisedView { + return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } // First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +func (vv VectorisedView) First() View { if len(vv.views) == 0 { return nil } @@ -137,23 +126,13 @@ func (vv *VectorisedView) RemoveFirst() { vv.views = vv.views[1:] } -// SetSize unsafely sets the size of the VectorisedView. -func (vv *VectorisedView) SetSize(size int) { - vv.size = size -} - -// SetViews unsafely sets the views of the VectorisedView. -func (vv *VectorisedView) SetViews(views []View) { - vv.views = views -} - // Size returns the size in bytes of the entire content stored in the vectorised view. -func (vv *VectorisedView) Size() int { +func (vv VectorisedView) Size() int { return vv.size } // ToView returns a single view containing the content of the vectorised view. -func (vv *VectorisedView) ToView() View { +func (vv VectorisedView) ToView() View { u := make([]byte, 0, vv.size) for _, v := range vv.views { u = append(u, v...) @@ -162,29 +141,6 @@ func (vv *VectorisedView) ToView() View { } // Views returns the slice containing the all views. -func (vv *VectorisedView) Views() []View { +func (vv VectorisedView) Views() []View { return vv.views } - -// ByteSlice returns a slice containing the all views as a []byte. -func (vv *VectorisedView) ByteSlice() [][]byte { - s := make([][]byte, len(vv.views)) - for i := range vv.views { - s[i] = []byte(vv.views[i]) - } - return s -} - -// copy returns a deep-copy of the vectorised view. -// It is an expensive method that should be used only in tests. -func (vv *VectorisedView) copy() *VectorisedView { - uu := &VectorisedView{ - views: make([]View, len(vv.views)), - size: vv.size, - } - for i, v := range vv.views { - uu.views[i] = make(View, len(v)) - copy(uu.views[i], v) - } - return uu -} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 57fe12360..02c264593 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -20,22 +20,33 @@ import ( "testing" ) +// copy returns a deep-copy of the vectorised view. +func (vv VectorisedView) copy() VectorisedView { + uu := VectorisedView{ + views: make([]View, 0, len(vv.views)), + size: vv.size, + } + for _, v := range vv.views { + uu.views = append(uu.views, append(View(nil), v...)) + } + return uu +} + // vv is an helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) *VectorisedView { +func vv(size int, pieces ...string) VectorisedView { views := make([]View, len(pieces)) for i, p := range pieces { views[i] = []byte(p) } - vv := NewVectorisedView(size, views) - return &vv + return NewVectorisedView(size, views) } var capLengthTestCases = []struct { comment string - in *VectorisedView + in VectorisedView length int - want *VectorisedView + want VectorisedView }{ { comment: "Simple case", @@ -88,9 +99,9 @@ func TestCapLength(t *testing.T) { var trimFrontTestCases = []struct { comment string - in *VectorisedView + in VectorisedView count int - want *VectorisedView + want VectorisedView }{ { comment: "Simple case", @@ -149,7 +160,7 @@ func TestTrimFront(t *testing.T) { var toViewCases = []struct { comment string - in *VectorisedView + in VectorisedView want View }{ { @@ -181,7 +192,7 @@ func TestToView(t *testing.T) { var toCloneCases = []struct { comment string - inView *VectorisedView + inView VectorisedView inBuffer []View }{ { @@ -213,10 +224,12 @@ var toCloneCases = []struct { func TestToClone(t *testing.T) { for _, c := range toCloneCases { - got := c.inView.Clone(c.inBuffer) - if !reflect.DeepEqual(&got, c.inView) { - t.Errorf("Test \"%s\" failed when calling Clone(%v) on %v. Got %v. Want %v", - c.comment, c.inBuffer, c.inView, got, c.inView) - } + t.Run(c.comment, func(t *testing.T) { + got := c.inView.Clone(c.inBuffer) + if !reflect.DeepEqual(got, c.inView) { + t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v", + c.inView, c.inBuffer, got, c.inView) + } + }) } } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 8e0e49efa..206531f20 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -39,40 +39,52 @@ type TransportChecker func(*testing.T, header.Transport) // // checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + ipv4 := header.IPv4(b) if !ipv4.IsValid(len(b)) { - t.Fatalf("Not a valid IPv4 packet") + t.Error("Not a valid IPv4 packet") } xsum := ipv4.CalculateChecksum() if xsum != 0 && xsum != 0xffff { - t.Fatalf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) + t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) } for _, f := range checkers { f(t, []header.Network{ipv4}) } + if t.Failed() { + t.FailNow() + } } // IPv6 checks the validity and properties of the given IPv6 packet. The usage // is similar to IPv4. func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + ipv6 := header.IPv6(b) if !ipv6.IsValid(len(b)) { - t.Fatalf("Not a valid IPv6 packet") + t.Error("Not a valid IPv6 packet") } for _, f := range checkers { f(t, []header.Network{ipv6}) } + if t.Failed() { + t.FailNow() + } } // SrcAddr creates a checker that checks the source address. func SrcAddr(addr tcpip.Address) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if a := h[0].SourceAddress(); a != addr { - t.Fatalf("Bad source address, got %v, want %v", a, addr) + t.Errorf("Bad source address, got %v, want %v", a, addr) } } } @@ -80,8 +92,10 @@ func SrcAddr(addr tcpip.Address) NetworkChecker { // DstAddr creates a checker that checks the destination address. func DstAddr(addr tcpip.Address) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if a := h[0].DestinationAddress(); a != addr { - t.Fatalf("Bad destination address, got %v, want %v", a, addr) + t.Errorf("Bad destination address, got %v, want %v", a, addr) } } } @@ -105,8 +119,10 @@ func TTL(ttl uint8) NetworkChecker { // PayloadLen creates a checker that checks the payload length. func PayloadLen(plen int) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if l := len(h[0].Payload()); l != plen { - t.Fatalf("Bad payload length, got %v, want %v", l, plen) + t.Errorf("Bad payload length, got %v, want %v", l, plen) } } } @@ -114,11 +130,13 @@ func PayloadLen(plen int) NetworkChecker { // FragmentOffset creates a checker that checks the FragmentOffset field. func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + // We only do this of IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.FragmentOffset(); v != offset { - t.Fatalf("Bad fragment offset, got %v, want %v", v, offset) + t.Errorf("Bad fragment offset, got %v, want %v", v, offset) } } } @@ -127,11 +145,13 @@ func FragmentOffset(offset uint16) NetworkChecker { // FragmentFlags creates a checker that checks the fragment flags field. func FragmentFlags(flags uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + // We only do this of IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.Flags(); v != flags { - t.Fatalf("Bad fragment offset, got %v, want %v", v, flags) + t.Errorf("Bad fragment offset, got %v, want %v", v, flags) } } } @@ -140,8 +160,10 @@ func FragmentFlags(flags uint8) NetworkChecker { // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if v, l := h[0].TOS(); v != tos || l != label { - t.Fatalf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) + t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) } } } @@ -153,8 +175,10 @@ func TOS(tos uint8, label uint32) NetworkChecker { // the bytes added by the IPv6 fragmentation. func Raw(want []byte) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { - t.Fatalf("Wrong payload, got %v, want %v", got, want) + t.Errorf("Wrong payload, got %v, want %v", got, want) } } } @@ -162,18 +186,23 @@ func Raw(want []byte) NetworkChecker { // IPv6Fragment creates a checker that validates an IPv6 fragment. func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) } ipv6Frag := header.IPv6Fragment(h[0].Payload()) if !ipv6Frag.IsValid() { - t.Fatalf("Not a valid IPv6 fragment") + t.Error("Not a valid IPv6 fragment") } for _, f := range checkers { f(t, []header.Network{h[0], ipv6Frag}) } + if t.Failed() { + t.FailNow() + } } } @@ -181,11 +210,13 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { // potentially additional transport header fields. func TCP(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + first := h[0] last := h[len(h)-1] if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Fatalf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) + t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) } // Verify the checksum. @@ -199,13 +230,16 @@ func TCP(checkers ...TransportChecker) NetworkChecker { xsum = header.Checksum(tcp, xsum) if xsum != 0 && xsum != 0xffff { - t.Fatalf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) + t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) } // Run the transport checkers. for _, f := range checkers { f(t, tcp) } + if t.Failed() { + t.FailNow() + } } } @@ -213,24 +247,31 @@ func TCP(checkers ...TransportChecker) NetworkChecker { // potentially additional transport header fields. func UDP(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { + t.Helper() + last := h[len(h)-1] if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) } udp := header.UDP(last.Payload()) for _, f := range checkers { f(t, udp) } + if t.Failed() { + t.FailNow() + } } } // SrcPort creates a checker that checks the source port. func SrcPort(port uint16) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + if p := h.SourcePort(); p != port { - t.Fatalf("Bad source port, got %v, want %v", p, port) + t.Errorf("Bad source port, got %v, want %v", p, port) } } } @@ -239,7 +280,7 @@ func SrcPort(port uint16) TransportChecker { func DstPort(port uint16) TransportChecker { return func(t *testing.T, h header.Transport) { if p := h.DestinationPort(); p != port { - t.Fatalf("Bad destination port, got %v, want %v", p, port) + t.Errorf("Bad destination port, got %v, want %v", p, port) } } } @@ -247,13 +288,15 @@ func DstPort(port uint16) TransportChecker { // SeqNum creates a checker that checks the sequence number. func SeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return } if s := tcp.SequenceNumber(); s != seq { - t.Fatalf("Bad sequence number, got %v, want %v", s, seq) + t.Errorf("Bad sequence number, got %v, want %v", s, seq) } } } @@ -268,7 +311,7 @@ func AckNum(seq uint32) TransportChecker { } if s := tcp.AckNumber(); s != seq { - t.Fatalf("Bad ack number, got %v, want %v", s, seq) + t.Errorf("Bad ack number, got %v, want %v", s, seq) } } } @@ -282,7 +325,7 @@ func Window(window uint16) TransportChecker { } if w := tcp.WindowSize(); w != window { - t.Fatalf("Bad window, got 0x%x, want 0x%x", w, window) + t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) } } } @@ -290,13 +333,15 @@ func Window(window uint16) TransportChecker { // TCPFlags creates a checker that checks the tcp flags. func TCPFlags(flags uint8) TransportChecker { return func(t *testing.T, h header.Transport) { + t.Helper() + tcp, ok := h.(header.TCP) if !ok { return } if f := tcp.Flags(); f != flags { - t.Fatalf("Bad flags, got 0x%x, want 0x%x", f, flags) + t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) } } } @@ -311,7 +356,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { } if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Fatalf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) + t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) } } } @@ -343,26 +388,26 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { case header.TCPOptionMSS: v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) if wantOpts.MSS != v { - t.Fatalf("Bad MSS: got %v, want %v", v, wantOpts.MSS) + t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) } foundMSS = true i += 4 case header.TCPOptionWS: if wantOpts.WS < 0 { - t.Fatalf("WS present when it shouldn't be") + t.Error("WS present when it shouldn't be") } v := int(opts[i+2]) if v != wantOpts.WS { - t.Fatalf("Bad WS: got %v, want %v", v, wantOpts.WS) + t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) } foundWS = true i += 3 case header.TCPOptionTS: if i+9 >= limit { - t.Fatalf("TS Option truncated , option is only: %d bytes, want 10", limit-i) + t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) } if opts[i+1] != 10 { - t.Fatalf("Bad length %d for TS option, limit: %d", opts[i+1], limit) + t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = uint32(0) @@ -375,10 +420,10 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { i += 10 case header.TCPOptionSACKPermitted: if i+1 >= limit { - t.Fatalf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) + t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) } if opts[i+1] != 2 { - t.Fatalf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) + t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) } foundSACKPermitted = true i += 2 @@ -389,23 +434,23 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { } if !foundMSS { - t.Fatalf("MSS option not found. Options: %x", opts) + t.Errorf("MSS option not found. Options: %x", opts) } if !foundWS && wantOpts.WS >= 0 { - t.Fatalf("WS option not found. Options: %x", opts) + t.Errorf("WS option not found. Options: %x", opts) } if wantOpts.TS && !foundTS { - t.Fatalf("TS option not found. Options: %x", opts) + t.Errorf("TS option not found. Options: %x", opts) } if foundTS && tsVal == 0 { - t.Fatalf("TS option specified but the timestamp value is zero") + t.Error("TS option specified but the timestamp value is zero") } if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Fatalf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) + t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) } if wantOpts.SACKPermitted && !foundSACKPermitted { - t.Fatalf("SACKPermitted option not found. Options: %x", opts) + t.Errorf("SACKPermitted option not found. Options: %x", opts) } } } @@ -435,10 +480,10 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp i++ case header.TCPOptionTS: if i+9 >= limit { - t.Fatalf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) + t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) } if opts[i+1] != 10 { - t.Fatalf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) + t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = binary.BigEndian.Uint32(opts[i+6:]) @@ -458,13 +503,13 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } if wantTS != foundTS { - t.Fatalf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) + t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) } if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Fatalf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) + t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) } if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Fatalf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) + t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) } } } @@ -497,12 +542,12 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { case header.TCPOptionSACK: if i+2 > limit { // Malformed SACK block. - t.Fatalf("malformed SACK option in options: %v", opts) + t.Errorf("malformed SACK option in options: %v", opts) } sackOptionLen := int(opts[i+1]) if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { // Malformed SACK block. - t.Fatalf("malformed SACK option length in options: %v", opts) + t.Errorf("malformed SACK option length in options: %v", opts) } numBlocks := sackOptionLen / 8 for j := 0; j < numBlocks; j++ { @@ -528,7 +573,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Fatalf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) + t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) } } } @@ -537,7 +582,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { func Payload(want []byte) TransportChecker { return func(t *testing.T, h header.Transport) { if got := h.Payload(); !reflect.DeepEqual(got, want) { - t.Fatalf("Wrong payload, got %v, want %v", got, want) + t.Errorf("Wrong payload, got %v, want %v", got, want) } } } diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 6983fae3f..a927a1b3f 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -66,15 +66,13 @@ func (e *Endpoint) Drain() int { } // Inject injects an inbound packet. -func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { - uu := vv.Clone(nil) - e.dispatcher.DeliverNetworkPacket(e, "", protocol, &uu) +func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { + e.InjectLinkAddr(protocol, "", vv) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv *buffer.VectorisedView) { - uu := vv.Clone(nil) - e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, protocol, &uu) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv buffer.VectorisedView) { + e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, protocol, vv.Clone(nil)) } // Attach saves the stack network-layer dispatcher for use later when packets diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 12c249c0d..0b985928b 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -57,7 +57,6 @@ type endpoint struct { // its end of the communication pipe. closed func(*tcpip.Error) - vv *buffer.VectorisedView iovecs []syscall.Iovec views []buffer.View dispatcher stack.NetworkDispatcher @@ -118,8 +117,6 @@ func New(opts *Options) tcpip.LinkEndpointID { iovecs: make([]syscall.Iovec, len(BufConfig)), handleLocal: opts.HandleLocal, } - vv := buffer.NewVectorisedView(0, e.views) - e.vv = &vv return stack.RegisterLinkEndpoint(e) } @@ -167,7 +164,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload views[0] = hdr.View() views = append(views, payload.Views()...) vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - e.dispatcher.DeliverNetworkPacket(e, r.RemoteLinkAddress, protocol, &vv) + e.dispatcher.DeliverNetworkPacket(e, r.RemoteLinkAddress, protocol, vv) return nil } if e.hdrSize > 0 { @@ -246,11 +243,10 @@ func (e *endpoint) dispatch(largeV buffer.View) (bool, *tcpip.Error) { } used := e.capViews(n, BufConfig) - e.vv.SetViews(e.views[:used]) - e.vv.SetSize(n) - e.vv.TrimFront(e.hdrSize) + vv := buffer.NewVectorisedView(n, e.views[:used]) + vv.TrimFront(e.hdrSize) - e.dispatcher.DeliverNetworkPacket(e, addr, p, e.vv) + e.dispatcher.DeliverNetworkPacket(e, addr, p, vv) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -290,7 +286,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // Inject injects an inbound packet. -func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { +func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { e.dispatcher.DeliverNetworkPacket(e, "", protocol, vv) } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 408169bbe..21d2f10b0 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -77,7 +77,7 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { c.ch <- packetInfo{remoteLinkAddr, protocol, vv.ToView()} } @@ -158,8 +158,7 @@ func TestWritePacket(t *testing.T) { payload[i] = uint8(rand.Intn(256)) } want := append(hdr.UsedBytes(), payload...) - vv := buffer.NewVectorisedView(len(payload), []buffer.View{payload}) - if err := c.ep.WritePacket(r, &hdr, vv, proto); err != nil { + if err := c.ep.WritePacket(r, &hdr, payload.ToVectorisedView(), proto); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 4a750fa12..884de83c9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -77,7 +77,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload views[0] = hdr.View() views = append(views, payload.Views()...) vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - e.dispatcher.DeliverNetworkPacket(e, "", protocol, &vv) + e.dispatcher.DeliverNetworkPacket(e, "", protocol, vv) return nil } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6bd5441f6..0dd23794b 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -227,8 +227,6 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Read in a loop until a stop is requested. var rxb []queue.RxBuffer - views := []buffer.View{nil} - vv := buffer.NewVectorisedView(0, views) for atomic.LoadUint32(&e.stopRequested) == 0 { var n uint32 rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested) @@ -250,9 +248,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b) - views[0] = b[header.EthernetMinimumSize:] - vv.SetSize(int(n) - header.EthernetMinimumSize) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.Type(), &vv) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView()) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 69d4ef29f..682c38400 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -129,7 +129,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteAddr, @@ -270,8 +270,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - vv := buffer.NewVectorisedView(len(buf), []buffer.View{buf}) - if err := c.ep.WritePacket(&r, &hdr, vv, proto); err != nil { + if err := c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), proto); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -330,7 +329,6 @@ func TestFillTxQueue(t *testing.T) { } buf := buffer.NewView(100) - vv := buffer.NewVectorisedView(len(buf), []buffer.View{buf}) // Each packet is uses no more than 40 bytes, so write that many packets // until the tx queue if full. @@ -338,7 +336,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -353,7 +351,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -374,12 +372,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { } buf := buffer.NewView(100) - vv := buffer.NewVectorisedView(len(buf), []buffer.View{buf}) // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -399,7 +396,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -414,7 +411,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -431,14 +428,13 @@ func TestFillTxMemory(t *testing.T) { } buf := buffer.NewView(100) - vv := buffer.NewVectorisedView(len(buf), []buffer.View{buf}) // Each packet is uses up one buffer, so write as many as possible until // we fill the memory. ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -454,7 +450,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber) + err := c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -474,13 +470,12 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { } buf := buffer.NewView(100) - vv := buffer.NewVectorisedView(len(buf), []buffer.View{buf}) // Each packet is uses up one buffer, so write as many as possible // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -490,20 +485,26 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { } // Attempt to write a two-buffer packet. It must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewVectorisedView(bufferSize, []buffer.View{buffer.NewView(bufferSize)}) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, &hdr, uu, header.IPv4ProtocolNumber); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + { + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + uu := buffer.NewView(bufferSize).ToVectorisedView() + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, &hdr, uu, header.IPv4ProtocolNumber); err != want { + t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + } } // Attempt to write the one-buffer packet again. It must succeed. - hdr = buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, &hdr, vv, header.IPv4ProtocolNumber); err != nil { - t.Fatalf("WritePacket failed unexpectedly: %v", err) + { + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + if err := c.ep.WritePacket(&r, &hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + t.Fatalf("WritePacket failed unexpectedly: %v", err) + } } } func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte { + t.Helper() + for { b := p.Pull() if b != nil { @@ -513,7 +514,7 @@ func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []by select { case <-time.After(10 * time.Millisecond): case <-to: - t.Fatalf(errStr) + t.Fatal(errStr) } } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 3bdc85210..5a70e062f 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -116,7 +116,7 @@ func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcp // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("recv", protocol, vv.First()) } diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 1c19a4509..cc1717ac7 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -51,7 +51,7 @@ func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { if !e.dispatchGate.Enter() { return } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 0719a95a9..f20ee2fcb 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { e.dispatchCount++ } @@ -106,21 +106,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil) + ep.dispatcher.DeliverNetworkPacket(ep, "", 0, buffer.VectorisedView{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil) + ep.dispatcher.DeliverNetworkPacket(ep, "", 0, buffer.VectorisedView{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil) + ep.dispatcher.DeliverNetworkPacket(ep, "", 0, buffer.VectorisedView{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 8f64e3f42..3f63daadd 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -83,7 +83,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { v := vv.First() h := header.ARP(v) if !h.IsValid() { diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 8fc79dc94..50628e4a2 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -96,54 +96,54 @@ func TestDirectRequest(t *testing.T) { copy(h.HardwareAddressSender(), senderMAC) copy(h.ProtocolAddressSender(), senderIPv4) - // stackAddr1 - copy(h.ProtocolAddressTarget(), stackAddr1) - vv := v.ToVectorisedView([1]buffer.View{}) - c.linkEP.Inject(arp.ProtocolNumber, &vv) - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) - } - rep := header.ARP(pkt.Header) - if !rep.IsValid() { - t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) - } - if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { - t.Errorf("stackAddr1: expected sender to be set") - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) + inject := func(addr tcpip.Address) { + copy(h.ProtocolAddressTarget(), addr) + c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) } - // stackAddr2 - copy(h.ProtocolAddressTarget(), stackAddr2) - vv = v.ToVectorisedView([1]buffer.View{}) - c.linkEP.Inject(arp.ProtocolNumber, &vv) - pkt = <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) + inject(stackAddr1) + { + pkt := <-c.linkEP.C + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { + t.Errorf("stackAddr1: expected sender to be set") + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) + } } - rep = header.ARP(pkt.Header) - if !rep.IsValid() { - t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) - } - if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { - t.Errorf("stackAddr2: expected sender to be set") - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) + + inject(stackAddr2) + { + pkt := <-c.linkEP.C + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { + t.Errorf("stackAddr2: expected sender to be set") + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) + } } - // stackAddrBad - copy(h.ProtocolAddressTarget(), stackAddrBad) - vv = v.ToVectorisedView([1]buffer.View{}) - c.linkEP.Inject(arp.ProtocolNumber, &vv) + inject(stackAddrBad) select { case pkt := <-c.linkEP.C: t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) case <-time.After(100 * time.Millisecond): - // Sleep tests are gross, but this will only - // potentially fail flakily if there's a bugj - // If there is no bug this will reliably succeed. + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. } } diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go index 073882e99..6c7faafe4 100644 --- a/pkg/tcpip/network/fragmentation/frag_heap.go +++ b/pkg/tcpip/network/fragmentation/frag_heap.go @@ -23,7 +23,7 @@ import ( type fragment struct { offset uint16 - vv *buffer.VectorisedView + vv buffer.VectorisedView } type fragHeap []fragment @@ -60,7 +60,7 @@ func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { size := curr.vv.Size() if curr.offset != 0 { - return buffer.NewVectorisedView(0, nil), fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) + return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) } for h.Len() > 0 { @@ -68,7 +68,7 @@ func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { if int(curr.offset) < size { curr.vv.TrimFront(size - int(curr.offset)) } else if int(curr.offset) > size { - return buffer.NewVectorisedView(0, nil), fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) + return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) } size += curr.vv.Size() views = append(views, curr.vv.Views()...) diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go index a2fe80264..a15540634 100644 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ b/pkg/tcpip/network/fragmentation/frag_heap_test.go @@ -25,7 +25,7 @@ import ( var reassambleTestCases = []struct { comment string in []fragment - want *buffer.VectorisedView + want buffer.VectorisedView }{ { comment: "Non-overlapping in-order", @@ -87,21 +87,25 @@ var reassambleTestCases = []struct { func TestReassamble(t *testing.T) { for _, c := range reassambleTestCases { - h := (fragHeap)(make([]fragment, 0, 8)) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, _ := h.reassemble() - - if !reflect.DeepEqual(got, *c.want) { - t.Errorf("Test \"%s\" reassembling failed. Got %v. Want %v", c.comment, got, *c.want) - } + t.Run(c.comment, func(t *testing.T) { + h := make(fragHeap, 0, 8) + heap.Init(&h) + for _, f := range c.in { + heap.Push(&h, f) + } + got, err := h.reassemble() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, c.want) { + t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) + } + }) } } func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := (fragHeap)(make([]fragment, 0, 8)) + h := make(fragHeap, 0, 8) heap.Init(&h) heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) _, err := h.reassemble() @@ -111,7 +115,7 @@ func TestReassambleFailsForNonZeroOffset(t *testing.T) { } func TestReassambleFailsForHoles(t *testing.T) { - h := (fragHeap)(make([]fragment, 0, 8)) + h := make(fragHeap, 0, 8) heap.Init(&h) heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 21497f876..885e3cca2 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -82,7 +82,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t // Process processes an incoming fragment beloning to an ID // and returns a complete packet when all the packets belonging to that ID have been received. -func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool) { +func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) { f.mu.Lock() r, ok := f.reassemblers[id] if ok && r.tooOld(f.timeout) { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 7320e594f..fc62a15dd 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -23,19 +23,13 @@ import ( ) // vv is a helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) *buffer.VectorisedView { +func vv(size int, pieces ...string) buffer.VectorisedView { views := make([]buffer.View, len(pieces)) for i, p := range pieces { views[i] = []byte(p) } - vv := buffer.NewVectorisedView(size, views) - return &vv -} - -func emptyVv() *buffer.VectorisedView { - vv := buffer.NewVectorisedView(0, nil) - return &vv + return buffer.NewVectorisedView(size, views) } type processInput struct { @@ -43,11 +37,11 @@ type processInput struct { first uint16 last uint16 more bool - vv *buffer.VectorisedView + vv buffer.VectorisedView } type processOutput struct { - vv *buffer.VectorisedView + vv buffer.VectorisedView done bool } @@ -63,7 +57,7 @@ var processTestCases = []struct { {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, }, out: []processOutput{ - {vv: emptyVv(), done: false}, + {vv: buffer.VectorisedView{}, done: false}, {vv: vv(4, "01", "23"), done: true}, }, }, @@ -76,8 +70,8 @@ var processTestCases = []struct { {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, }, out: []processOutput{ - {vv: emptyVv(), done: false}, - {vv: emptyVv(), done: false}, + {vv: buffer.VectorisedView{}, done: false}, + {vv: buffer.VectorisedView{}, done: false}, {vv: vv(4, "ab", "cd"), done: true}, {vv: vv(4, "01", "23"), done: true}, }, @@ -86,26 +80,28 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) - for i, in := range c.in { - vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) - if !reflect.DeepEqual(vv, *(c.out[i].vv)) { - t.Errorf("Test \"%s\" Process() returned a wrong vv. Got %v. Want %v", c.comment, vv, *(c.out[i].vv)) - } - if done != c.out[i].done { - t.Errorf("Test \"%s\" Process() returned a wrong done. Got %t. Want %t", c.comment, done, c.out[i].done) - } - if c.out[i].done { - if _, ok := f.reassemblers[in.id]; ok { - t.Errorf("Test \"%s\" Process() didn't remove buffer from reassemblers.", c.comment) + t.Run(c.comment, func(t *testing.T) { + f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + for i, in := range c.in { + vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) + if !reflect.DeepEqual(vv, c.out[i].vv) { + t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) + } + if done != c.out[i].done { + t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) } - for n := f.rList.Front(); n != nil; n = n.Next() { - if n.id == in.id { - t.Errorf("Test \"%s\" Process() didn't remove buffer from rList.", c.comment) + if c.out[i].done { + if _, ok := f.reassemblers[in.id]; ok { + t.Errorf("Process(%d) did not remove buffer from reassemblers", i) + } + for n := f.rList.Front(); n != nil; n = n.Next() { + if n.id == in.id { + t.Errorf("Process(%d) did not remove buffer from rList", i) + } } } } - } + }) } } @@ -161,16 +157,3 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) } } - -func TestFragmentationViewsDoNotEscape(t *testing.T) { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) - in := vv(2, "0", "1") - f.Process(0, 0, 1, true, in) - // Modify input view. - in.RemoveFirst() - got, _ := f.Process(0, 2, 2, false, vv(1, "2")) - want := vv(3, "0", "1", "2") - if !reflect.DeepEqual(got, *want) { - t.Errorf("Process() returned a wrong vv. Got %v. Want %v", got, *want) - } -} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 7c465c1ac..b57fe82ec 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool, int) { +func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -86,18 +86,17 @@ func (r *reassembler) process(first, last uint16, more bool, vv *buffer.Vectoris // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.NewVectorisedView(0, nil), false, consumed + return buffer.VectorisedView{}, false, consumed } if r.updateHoles(first, last, more) { // We store the incoming packet only if it filled some holes. - uu := vv.Clone(nil) - heap.Push(&r.heap, fragment{offset: first, vv: &uu}) + heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) consumed = vv.Size() r.size += consumed } // Check if all the holes have been deleted and we are ready to reassamble. if r.deleted < len(r.holes) { - return buffer.NewVectorisedView(0, nil), false, consumed + return buffer.VectorisedView{}, false, consumed } res, err := r.heap.reassemble() if err != nil { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 0cb53fb42..fe6bf0441 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -94,16 +94,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) { - t.checkValues(protocol, *vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { + t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { - t.checkValues(trans, *vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { + t.checkValues(trans, vv, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -221,8 +221,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - vv := buffer.NewVectorisedView(len(payload), []buffer.View{payload}) - if err := ep.WritePacket(&r, &hdr, vv, 123, 123); err != nil { + if err := ep.WritePacket(&r, &hdr, payload.ToVectorisedView(), 123, 123); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -262,9 +261,7 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - var views [1]buffer.View - vv := view.ToVectorisedView(views) - ep.HandlePacket(&r, &vv) + ep.HandlePacket(&r, view.ToVectorisedView()) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -296,7 +293,6 @@ func TestIPv4ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - var views [1]buffer.View o := testObject{t: t} proto := ipv4.NewProtocol() ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) @@ -351,9 +347,8 @@ func TestIPv4ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view.ToVectorisedView(views) - vv.CapLength(len(view) - c.trunc) - ep.HandlePacket(&r, &vv) + vv := view[:len(view)-c.trunc].ToVectorisedView() + ep.HandlePacket(&r, vv) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -416,18 +411,13 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - var views1 [1]buffer.View - vv1 := frag1.ToVectorisedView(views1) - ep.HandlePacket(&r, &vv1) + ep.HandlePacket(&r, frag1.ToVectorisedView()) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - var views2 [1]buffer.View - vv2 := frag2.ToVectorisedView(views2) - ep.HandlePacket(&r, &vv2) - + ep.HandlePacket(&r, frag2.ToVectorisedView()) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -460,8 +450,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - vv := buffer.NewVectorisedView(len(payload), []buffer.View{payload}) - if err := ep.WritePacket(&r, &hdr, vv, 123, 123); err != nil { + if err := ep.WritePacket(&r, &hdr, payload.ToVectorisedView(), 123, 123); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -501,10 +490,7 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - var views [1]buffer.View - vv := view.ToVectorisedView(views) - ep.HandlePacket(&r, &vv) - + ep.HandlePacket(&r, view.ToVectorisedView()) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -541,7 +527,6 @@ func TestIPv6ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - var views [1]buffer.View o := testObject{t: t} proto := ipv6.NewProtocol() ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) @@ -609,9 +594,8 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view.ToVectorisedView(views) - vv.CapLength(len(view) - c.trunc) - ep.HandlePacket(&r, &vv) + vv := view[:len(view)-c.trunc].ToVectorisedView() + ep.HandlePacket(&r, vv) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 74454f605..ab2fe8440 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { h := header.IPv4(vv.First()) // We don't use IsValid() here because ICMP only requires that the IP @@ -55,7 +55,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) } -func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { v := vv.First() if len(v) < header.ICMPv4MinimumSize { return @@ -120,6 +120,5 @@ func sendPing4(r *stack.Route, code byte, data buffer.View) *tcpip.Error { data = data[header.ICMPv4EchoMinimumSize-header.ICMPv4MinimumSize:] icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - vv := buffer.NewVectorisedView(len(data), []buffer.View{data}) - return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(&hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 0a2378a6a..877f34be8 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -133,7 +133,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { h := header.IPv4(vv.First()) if !h.IsValid(vv.Size()) { return @@ -148,11 +148,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { if more || h.FragmentOffset() != 0 { // The packet is a fragment, let's try to reassemble it. last := h.FragmentOffset() + uint16(vv.Size()) - 1 - tt, ready := e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + var ready bool + vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) if !ready { return } - vv = &tt } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2158ba8f7..c6fcf58d2 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { h := header.IPv6(vv.First()) // We don't use IsValid() here because ICMP only requires that up to @@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) } -func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { v := vv.First() if len(v) < header.ICMPv6MinimumSize { return @@ -129,8 +129,8 @@ func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, *vv)) - r.WritePacket(&hdr, *vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) + r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) case header.ICMPv6EchoReply: if len(v) < header.ICMPv6EchoMinimumSize { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index e9f400fe4..c48859be3 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -165,7 +165,7 @@ func (c *testContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.End views := []buffer.View{pkt.Header, pkt.Payload} size := len(pkt.Header) + len(pkt.Payload) vv := buffer.NewVectorisedView(size, views) - ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), &vv) + ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), vv) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index eb89168c3..8d5ae8303 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -101,7 +101,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { h := header.IPv6(vv.First()) if !h.IsValid(vv.Size()) { return diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 61afa673e..4c027e91a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -391,7 +391,7 @@ func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { netProto, ok := n.stack.networkProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -459,7 +459,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.Lin // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -502,7 +502,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { state, ok := n.stack.transportProtocols[trans] if !ok { return diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 7afde3598..acd3fa01b 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,11 +64,11 @@ const ( type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView) + HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv *buffer.VectorisedView) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) } // TransportProtocol is the interface that needs to be implemented by transport @@ -95,7 +95,7 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -114,11 +114,11 @@ type TransportProtocol interface { type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate // transport protocol endpoint. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -155,7 +155,7 @@ type NetworkEndpoint interface { // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. - HandlePacket(r *Route, vv *buffer.VectorisedView) + HandlePacket(r *Route, vv buffer.VectorisedView) // Close is called when the endpoint is reomved from a stack. Close() @@ -196,7 +196,7 @@ type NetworkProtocol interface { type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol // endpoint and hands the packet over for further processing. - DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) + DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) } // LinkEndpointCapabilities is the type associated with the capabilities diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 789c819dd..2d313cc27 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -48,7 +48,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(*Route, TransportEndpointID, *buffer.VectorisedView) bool + defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -428,7 +428,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 816707d27..279867315 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -73,7 +73,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ @@ -205,15 +205,12 @@ func TestNetworkReceive(t *testing.T) { } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - var views [1]buffer.View - // Allocate the buffer containing the packet that will be injected into - // the stack. + buf := buffer.NewView(30) // Make sure packet with wrong address is not delivered. buf[0] = 3 - vv := buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -223,8 +220,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -234,8 +230,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -244,8 +239,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber-1, &vv) + linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -255,8 +249,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -460,16 +453,14 @@ func TestAddressRemoval(t *testing.T) { t.Fatalf("AddAddress failed: %v", err) } - var views [1]buffer.View - buf := buffer.NewView(30) - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) + // Write a packet, and check that it gets delivered. fakeNet.packetCount[1] = 0 buf[0] = 1 - vv := buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -480,8 +471,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatalf("RemoveAddress failed: %v", err) } - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -510,14 +500,12 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - var views [1]buffer.View buf := buffer.NewView(30) // Write a packet, and check that it gets delivered. fakeNet.packetCount[1] = 0 buf[0] = 1 - vv := buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -528,8 +516,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { t.Fatalf("FindRoute failed: %v", err) } - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 2 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) } @@ -540,8 +527,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { t.Fatalf("RemoveAddress failed: %v", err) } - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 3 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) } @@ -553,8 +539,7 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { // Release the route, then check that packet is not deliverable anymore. r.Release() - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 3 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) } @@ -574,15 +559,13 @@ func TestPromiscuousMode(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - var views [1]buffer.View buf := buffer.NewView(30) // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. fakeNet.packetCount[1] = 0 buf[0] = 1 - vv := buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -592,8 +575,7 @@ func TestPromiscuousMode(t *testing.T) { t.Fatalf("SetPromiscuousMode failed: %v", err) } - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -610,8 +592,7 @@ func TestPromiscuousMode(t *testing.T) { t.Fatalf("SetPromiscuousMode failed: %v", err) } - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -675,8 +656,8 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - var views [1]buffer.View buf := buffer.NewView(30) + buf[0] = 1 fakeNet.packetCount[1] = 0 subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) @@ -687,8 +668,7 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { t.Fatalf("AddSubnet failed: %v", err) } - vv := buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -709,8 +689,8 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - var views [1]buffer.View buf := buffer.NewView(30) + buf[0] = 1 fakeNet.packetCount[1] = 0 subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) @@ -720,8 +700,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { t.Fatalf("AddSubnet failed: %v", err) } - vv := buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 862afa693..a7470d606 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -100,7 +100,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to deliver the given packet. Returns true if it found // an endpoint, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -127,7 +127,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -149,7 +149,7 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv *buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { // Try to find a match with the id as provided. if ep := eps.endpoints[id]; ep != nil { return ep diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index e4607192f..226191525 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -70,8 +70,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) if err != nil { return 0, err } - vv := buffer.NewVectorisedView(len(v), []buffer.View{v}) - if err := f.route.WritePacket(&hdr, vv, fakeTransNumber, 123); err != nil { + if err := f.route.WritePacket(&hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { return 0, err } @@ -149,12 +148,12 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) { // Increment the number of received packets. f.proto.packetCount++ } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -193,7 +192,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { return true } @@ -245,15 +244,13 @@ func TestTransportReceive(t *testing.T) { fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - var views [1]buffer.View // Create buffer that will hold the packet. buf := buffer.NewView(30) // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - vv := buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -262,8 +259,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -272,8 +268,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -305,7 +300,6 @@ func TestTransportControlReceive(t *testing.T) { fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) - var views [1]buffer.View // Create buffer that will hold the control packet. buf := buffer.NewView(2*fakeNetHeaderLen + 30) @@ -318,8 +312,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - vv := buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -328,8 +321,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -338,8 +330,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - vv = buf.ToVectorisedView(views) - linkEP.Inject(fakeNetNumber, &vv) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go index 7aaf2d9c6..fea9d6957 100644 --- a/pkg/tcpip/transport/ping/endpoint.go +++ b/pkg/tcpip/transport/ping/endpoint.go @@ -384,8 +384,7 @@ func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - vv := buffer.NewVectorisedView(len(data), []buffer.View{data}) - return r.WritePacket(&hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(&hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) } func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { @@ -409,8 +408,7 @@ func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv6.SetChecksum(0) icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) - vv := buffer.NewVectorisedView(len(data), []buffer.View{data}) - return r.WritePacket(&hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(&hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { @@ -675,7 +673,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -711,5 +709,5 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } diff --git a/pkg/tcpip/transport/ping/protocol.go b/pkg/tcpip/transport/ping/protocol.go index b885f3627..549b1b2d3 100644 --- a/pkg/tcpip/transport/ping/protocol.go +++ b/pkg/tcpip/transport/ping/protocol.go @@ -99,7 +99,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { return true } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 60e9daf74..4085585b0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1323,7 +1323,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { s := newSegment(r, id, vv) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() @@ -1348,7 +1348,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 8a873db73..c80f3c7d6 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -63,7 +63,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index fe21f2c78..abdc825cd 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -120,7 +120,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { s := newSegment(r, id, vv) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 8dccea2ba..51a3d6aba 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -60,7 +60,7 @@ type segment struct { options []byte `state:".([]byte)"` } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { s := &segment{ refCnt: 1, id: id, diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index bf26ea24e..871177842 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2668,11 +2668,11 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ByteSlice()[0][header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] // 12 is the TCP header data offset. tcpbuf[12] = ((header.TCPMinimumSize - 1) / 4) << 4 - c.SendSegment(&vv) + c.SendSegment(vv) if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index c46af4b8b..5b25534f4 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -205,9 +205,11 @@ func (c *Context) Stack() *stack.Stack { // CheckNoPacketTimeout verifies that no packet is received during the time // specified by wait. func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { + c.t.Helper() + select { case <-c.linkEP.C: - c.t.Fatalf(errMsg) + c.t.Fatal(errMsg) case <-time.After(wait): } @@ -290,9 +292,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2) // Inject packet. - var views [1]buffer.View - vv := buf.ToVectorisedView(views) - c.linkEP.Inject(ipv4.ProtocolNumber, &vv) + c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -337,23 +337,19 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView t.SetChecksum(^t.CalculateChecksum(xsum, length)) // Inject packet. - var views [1]buffer.View - vv := buf.ToVectorisedView(views) - - return vv + return buf.ToVectorisedView() } // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. -func (c *Context) SendSegment(s *buffer.VectorisedView) { +func (c *Context) SendSegment(s buffer.VectorisedView) { c.linkEP.Inject(ipv4.ProtocolNumber, s) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - vv := c.BuildSegment(payload, h) - c.linkEP.Inject(ipv4.ProtocolNumber, &vv) + c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) } // SendAck sends an ACK packet. @@ -496,9 +492,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { t.SetChecksum(^t.CalculateChecksum(xsum, length)) // Inject packet. - var views [1]buffer.View - vv := buf.ToVectorisedView(views) - c.linkEP.Inject(ipv6.ProtocolNumber, &vv) + c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } // CreateConnected creates a connected TCP endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5de518a55..e9337a88e 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -362,8 +362,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc ttl = e.multicastTTL } - vv := buffer.NewVectorisedView(len(v), []buffer.View{v}) - if err := sendUDP(route, vv, e.id.LocalPort, dstPort, ttl); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { return 0, err } return uintptr(len(v)), nil @@ -843,7 +842,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { // Get the header then trim it from the view. hdr := header.UDP(vv.First()) if int(hdr.Length()) > vv.Size() { @@ -892,5 +891,5 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index dabc5bd13..1334fec8a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -62,7 +62,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 6d7a737bd..46110c8ff 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -204,9 +204,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { u.SetChecksum(^u.CalculateChecksum(xsum, length)) // Inject packet. - var views [1]buffer.View - vv := buf.ToVectorisedView(views) - c.linkEP.Inject(ipv6.ProtocolNumber, &vv) + c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } func (c *testContext) sendPacket(payload []byte, h *headers) { @@ -245,9 +243,7 @@ func (c *testContext) sendPacket(payload []byte, h *headers) { u.SetChecksum(^u.CalculateChecksum(xsum, length)) // Inject packet. - var views [1]buffer.View - vv := buf.ToVectorisedView(views) - c.linkEP.Inject(ipv4.ProtocolNumber, &vv) + c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) } func newPayload() []byte { |