diff options
Diffstat (limited to 'pkg/tcpip/network/ipv6')
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 58 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 251 |
2 files changed, 242 insertions, 67 deletions
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 36d98caef..3210e6fc7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -63,15 +63,22 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { + stats := r.Stats().ICMP + sent := stats.V6PacketsSent + received := stats.V6PacketsReceived v := vv.First() if len(v) < header.ICMPv6MinimumSize { + received.Invalid.Increment() return } h := header.ICMPv6(v) + // TODO: Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv6PacketTooBig: + received.PacketTooBig.Increment() if len(v) < header.ICMPv6PacketTooBigMinimumSize { + received.Invalid.Increment() return } vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) @@ -79,7 +86,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) case header.ICMPv6DstUnreachable: + received.DstUnreachable.Increment() if len(v) < header.ICMPv6DstUnreachableMinimumSize { + received.Invalid.Increment() return } vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) @@ -89,15 +98,21 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V } case header.ICMPv6NeighborSolicit: + received.NeighborSolicit.Increment() + + e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + if len(v) < header.ICMPv6NeighborSolicitMinimumSize { + received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8 : 8+16]) + targetAddr := tcpip.Address(v[8:][:16]) if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) pkt.SetType(header.ICMPv6NeighborAdvert) pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag @@ -118,22 +133,29 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V defer r.Release() r.LocalAddress = targetAddr pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()) - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + sent.Dropped.Increment() + return + } + sent.NeighborAdvert.Increment() case header.ICMPv6NeighborAdvert: + received.NeighborAdvert.Increment() if len(v) < header.ICMPv6NeighborAdvertSize { + received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8 : 8+16]) + targetAddr := tcpip.Address(v[8:][:16]) e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) } case header.ICMPv6EchoRequest: + received.EchoRequest.Increment() if len(v) < header.ICMPv6EchoMinimumSize { + received.Invalid.Increment() return } @@ -143,14 +165,37 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + sent.Dropped.Increment() + return + } + sent.EchoReply.Increment() case header.ICMPv6EchoReply: + received.EchoReply.Increment() if len(v) < header.ICMPv6EchoMinimumSize { + received.Invalid.Increment() return } e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + case header.ICMPv6TimeExceeded: + received.TimeExceeded.Increment() + + case header.ICMPv6ParamProblem: + received.ParamProblem.Increment() + + case header.ICMPv6RouterSolicit: + received.RouterSolicit.Increment() + + case header.ICMPv6RouterAdvert: + received.RouterAdvert.Increment() + + case header.ICMPv6RedirectMsg: + received.RedirectMsg.Increment() + + default: + received.Invalid.Increment() } } @@ -202,6 +247,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. DstAddr: r.RemoteAddress, }) + // TODO: count this in ICMP stats. return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index eee09f3af..8b57a0641 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,9 +15,10 @@ package ipv6 import ( + "fmt" + "reflect" "strings" "testing" - "time" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" @@ -39,20 +40,151 @@ var ( lladdr1 = header.LinkLocalAddr(linkAddr1) ) -type icmpInfo struct { - typ header.ICMPv6Type - src tcpip.Address +type stubLinkEndpoint struct { + stack.LinkEndpoint +} + +func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +func (*stubLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { + return nil +} + +func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {} + +type stubDispatcher struct { + stack.TransportDispatcher +} + +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +} + +type stubLinkAddressCache struct { + stack.LinkAddressCache +} + +func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID { + return 0 +} + +func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { +} + +func TestICMPCounts(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) + { + id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: lladdr1, + Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), + NIC: 1, + }}, + ) + + ep, err := s.NetworkProtocolInstance(ProtocolNumber).NewEndpoint(0, lladdr1, &stubLinkAddressCache{}, &stubDispatcher{}, nil) + if err != nil { + t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + } + + r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + } + defer r.Release() + + types := []struct { + typ header.ICMPv6Type + size int + }{ + {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize}, + {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize}, + {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize}, + {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize}, + {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, + {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, + {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, + {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, + {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, + {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, + {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, + } + + handleIPv6Payload := func(hdr buffer.Prependable) { + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: r.DefaultTTL(), + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + } + + for _, typ := range types { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) + pkt := header.ICMPv6(hdr.Prepend(typ.size)) + pkt.SetType(typ.typ) + pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + handleIPv6Payload(hdr) + } + + // Construct an empty ICMP packet so that + // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. + handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) + + icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { + if got, want := s.Value(), uint64(1); got != want { + t.Errorf("got %s = %d, want = %d", name, got, want) + } + }) + if t.Failed() { + t.Logf("stats:\n%+v", s.Stats()) + } +} + +func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { + t := v.Type() + for i := 0; i < v.NumField(); i++ { + v := v.Field(i) + switch v.Kind() { + case reflect.Ptr: + f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter)) + case reflect.Struct: + visitStats(v, f) + default: + panic(fmt.Sprintf("unexpected type %s", v.Type())) + } + } } type testContext struct { - t *testing.T s0 *stack.Stack s1 *stack.Stack linkEP0 *channel.Endpoint linkEP1 *channel.Endpoint - - icmpCh chan icmpInfo } type endpointWithResolutionCapability struct { @@ -65,10 +197,8 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ - t: t, - s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - icmpCh: make(chan icmpInfo, 10), + s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), + s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), } const defaultMTU = 65536 @@ -118,50 +248,58 @@ func newTestContext(t *testing.T) *testContext { }}, ) - go c.routePackets(linkEP0.C, linkEP1) - go c.routePackets(linkEP1.C, linkEP0) - return c } -func (c *testContext) countPacket(pkt channel.PacketInfo) { +func (c *testContext) cleanup() { + close(c.linkEP0.C) + close(c.linkEP1.C) +} + +type routeArgs struct { + src, dst *channel.Endpoint + typ header.ICMPv6Type +} + +func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { + t.Helper() + + pkt := <-args.src.C + + { + views := []buffer.View{pkt.Header, pkt.Payload} + size := len(pkt.Header) + len(pkt.Payload) + vv := buffer.NewVectorisedView(size, views) + args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + } + if pkt.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pkt.Proto) return } ipv6 := header.IPv6(pkt.Header) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { + t.Errorf("unexpected transport protocol number %d", transProto) return } - b := pkt.Header[header.IPv6MinimumSize:] - icmp := header.ICMPv6(b) - c.icmpCh <- icmpInfo{ - typ: icmp.Type(), - src: ipv6.SourceAddress(), + icmpv6 := header.ICMPv6(ipv6.Payload()) + if got, want := icmpv6.Type(), args.typ; got != want { + t.Errorf("got ICMPv6 type = %d, want = %d", got, want) + return } -} - -func (c *testContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) { - for pkt := range ch { - c.countPacket(pkt) - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) - vv := buffer.NewVectorisedView(size, views) - ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), vv) + if fn != nil { + fn(t, icmpv6) } } -func (c *testContext) cleanup() { - close(c.linkEP0.C) - close(c.linkEP1.C) -} - func TestLinkResolution(t *testing.T) { c := newTestContext(t) defer c.cleanup() + r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatal(err) + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) } defer r.Release() @@ -176,14 +314,24 @@ func TestLinkResolution(t *testing.T) { var wq waiter.Queue ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) if err != nil { - t.Fatal(err) + t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) } for { _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}) if resCh != nil { if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = _, <non-nil>, %s want _, <non-nil>, tcpip.ErrNoLinkAddress", err) + t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) + } + for _, args := range []routeArgs{ + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit}, + {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, + } { + routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { + if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { + t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) + } + }) } <-resCh continue @@ -194,29 +342,10 @@ func TestLinkResolution(t *testing.T) { break } - stats := make(map[header.ICMPv6Type]int) - for { - // This actually takes about 10 milliseconds, so no need to wait for - // a multi-minute go test timeout if something is broken. - select { - case <-time.After(2 * time.Second): - t.Errorf("timeout waiting for ICMP, got: %#+v", stats) - return - case icmpInfo := <-c.icmpCh: - switch icmpInfo.typ { - case header.ICMPv6NeighborAdvert: - if got, want := icmpInfo.src, lladdr1; got != want { - t.Errorf("got ICMPv6NeighborAdvert.sourceAddress = %v, want = %v", got, want) - } - } - stats[icmpInfo.typ]++ - - if stats[header.ICMPv6NeighborSolicit] > 0 && - stats[header.ICMPv6NeighborAdvert] > 0 && - stats[header.ICMPv6EchoRequest] > 0 && - stats[header.ICMPv6EchoReply] > 0 { - return - } - } + for _, args := range []routeArgs{ + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, + {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, + } { + routeICMPv6Packet(t, args, nil) } } |