diff options
author | Googler <noreply@google.com> | 2018-04-27 10:37:02 -0700 |
---|---|---|
committer | Adin Scannell <ascannell@google.com> | 2018-04-28 01:44:26 -0400 |
commit | d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch) | |
tree | 54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip/network/ipv4 | |
parent | f70210e742919f40aa2f0934a22f1c9ba6dada62 (diff) |
Check in gVisor.
PiperOrigin-RevId: 194583126
Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 38 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 282 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp_test.go | 124 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 233 |
4 files changed, 677 insertions, 0 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD new file mode 100644 index 000000000..9df113df1 --- /dev/null +++ b/pkg/tcpip/network/ipv4/BUILD @@ -0,0 +1,38 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "ipv4", + srcs = [ + "icmp.go", + "ipv4.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4", + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/network/fragmentation", + "//pkg/tcpip/network/hash", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) + +go_test( + name = "ipv4_test", + size = "small", + srcs = ["icmp_test.go"], + deps = [ + ":ipv4", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go new file mode 100644 index 000000000..ffd761350 --- /dev/null +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -0,0 +1,282 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipv4 + +import ( + "context" + "encoding/binary" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// PingProtocolName is a pseudo transport protocol used to handle ping replies. +// Use it when constructing a stack that intends to use ipv4.Ping. +const PingProtocolName = "icmpv4ping" + +// pingProtocolNumber is a fake transport protocol used to +// deliver incoming ICMP echo replies. The ICMP identifier +// number is used as a port number for multiplexing. +const pingProtocolNumber tcpip.TransportProtocolNumber = 256 + 11 + +// handleControl handles the case when an ICMP packet contains the headers of +// the original packet that caused the ICMP one to be sent. This information is +// used to find out which transport endpoint must be notified about the ICMP +// packet. +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { + h := header.IPv4(vv.First()) + + // We don't use IsValid() here because ICMP only requires that the IP + // header plus 8 bytes of the transport header be included. So it's + // likely that it is truncated, which would cause IsValid to return + // false. + // + // Drop packet if it doesn't have the basic IPv4 header or if the + // original source address doesn't match the endpoint's address. + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + return + } + + hlen := int(h.HeaderLength()) + if vv.Size() < hlen || h.FragmentOffset() != 0 { + // We won't be able to handle this if it doesn't contain the + // full IPv4 header, or if it's a fragment not at offset 0 + // (because it won't have the transport header). + return + } + + // Skip the ip header, then deliver control message. + vv.TrimFront(hlen) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) +} + +func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) { + v := vv.First() + if len(v) < header.ICMPv4MinimumSize { + return + } + h := header.ICMPv4(v) + + switch h.Type() { + case header.ICMPv4Echo: + if len(v) < header.ICMPv4EchoMinimumSize { + return + } + vv.TrimFront(header.ICMPv4MinimumSize) + req := echoRequest{r: r.Clone(), v: vv.ToView()} + select { + case e.echoRequests <- req: + default: + req.r.Release() + } + + case header.ICMPv4EchoReply: + e.dispatcher.DeliverTransportPacket(r, pingProtocolNumber, vv) + + case header.ICMPv4DstUnreachable: + if len(v) < header.ICMPv4DstUnreachableMinimumSize { + return + } + vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize) + switch h.Code() { + case header.ICMPv4PortUnreachable: + e.handleControl(stack.ControlPortUnreachable, 0, vv) + + case header.ICMPv4FragmentationNeeded: + mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:])) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + } + } + // TODO: Handle other ICMP types. +} + +type echoRequest struct { + r stack.Route + v buffer.View +} + +func (e *endpoint) echoReplier() { + for req := range e.echoRequests { + sendICMPv4(&req.r, header.ICMPv4EchoReply, 0, req.v) + req.r.Release() + } +} + +func sendICMPv4(r *stack.Route, typ header.ICMPv4Type, code byte, data buffer.View) *tcpip.Error { + hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpv4.SetType(typ) + icmpv4.SetCode(code) + icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + + return r.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber) +} + +// A Pinger can send echo requests to an address. +type Pinger struct { + Stack *stack.Stack + NICID tcpip.NICID + Addr tcpip.Address + LocalAddr tcpip.Address // optional + Wait time.Duration // if zero, defaults to 1 second + Count uint16 // if zero, defaults to MaxUint16 +} + +// Ping sends echo requests to an ICMPv4 endpoint. +// Responses are streamed to the channel ch. +func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error { + count := p.Count + if count == 0 { + count = 1<<16 - 1 + } + wait := p.Wait + if wait == 0 { + wait = 1 * time.Second + } + + r, err := p.Stack.FindRoute(p.NICID, p.LocalAddr, p.Addr, ProtocolNumber) + if err != nil { + return err + } + + netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber} + ep := &pingEndpoint{ + stack: p.Stack, + pktCh: make(chan buffer.View, 1), + } + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + RemoteAddress: p.Addr, + } + + _, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) { + id.LocalPort = port + err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id, ep) + switch err { + case nil: + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }) + if err != nil { + return err + } + defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id) + + v := buffer.NewView(4) + binary.BigEndian.PutUint16(v[0:], id.LocalPort) + + start := time.Now() + + done := make(chan struct{}) + go func(count int) { + loop: + for ; count > 0; count-- { + select { + case v := <-ep.pktCh: + seq := binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize+2:]) + ch <- PingReply{ + Duration: time.Since(start) - time.Duration(seq)*wait, + SeqNumber: seq, + } + case <-ctx.Done(): + break loop + } + } + close(done) + }(int(count)) + defer func() { <-done }() + + t := time.NewTicker(wait) + defer t.Stop() + for seq := uint16(0); seq < count; seq++ { + select { + case <-t.C: + case <-ctx.Done(): + return nil + } + binary.BigEndian.PutUint16(v[2:], seq) + sent := time.Now() + if err := sendICMPv4(&r, header.ICMPv4Echo, 0, v); err != nil { + ch <- PingReply{ + Error: err, + Duration: time.Since(sent), + SeqNumber: seq, + } + } + } + return nil +} + +// PingReply summarizes an ICMP echo reply. +type PingReply struct { + Error *tcpip.Error // reports any errors sending a ping request + Duration time.Duration + SeqNumber uint16 +} + +type pingProtocol struct{} + +func (*pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return nil, tcpip.ErrNotSupported // endpoints are created directly +} + +func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return pingProtocolNumber } + +func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize } + +func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + ident := binary.BigEndian.Uint16(v[4:]) + return 0, ident, nil +} + +func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool { + return true +} + +// SetOption implements TransportProtocol.SetOption. +func (p *pingProtocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements TransportProtocol.Option. +func (p *pingProtocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func init() { + stack.RegisterTransportProtocolFactory(PingProtocolName, func() stack.TransportProtocol { + return &pingProtocol{} + }) +} + +type pingEndpoint struct { + stack *stack.Stack + pktCh chan buffer.View +} + +func (e *pingEndpoint) Close() { + close(e.pktCh) +} + +func (e *pingEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) { + select { + case e.pktCh <- vv.ToView(): + default: + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *pingEndpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { +} diff --git a/pkg/tcpip/network/ipv4/icmp_test.go b/pkg/tcpip/network/ipv4/icmp_test.go new file mode 100644 index 000000000..378fba74b --- /dev/null +++ b/pkg/tcpip/network/ipv4/icmp_test.go @@ -0,0 +1,124 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipv4_test + +import ( + "context" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const stackAddr = "\x0a\x00\x00\x01" + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack +} + +func newTestContext(t *testing.T) *testContext { + s := stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}) + + const defaultMTU = 65536 + id, linkEP := channel.New(256, defaultMTU, "") + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }}) + + return &testContext{ + t: t, + s: s, + linkEP: linkEP, + } +} + +func (c *testContext) cleanup() { + close(c.linkEP.C) +} + +func (c *testContext) loopback() { + go func() { + for pkt := range c.linkEP.C { + v := make(buffer.View, len(pkt.Header)+len(pkt.Payload)) + copy(v, pkt.Header) + copy(v[len(pkt.Header):], pkt.Payload) + vv := v.ToVectorisedView([1]buffer.View{}) + c.linkEP.Inject(pkt.Proto, &vv) + } + }() +} + +func TestEcho(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + c.loopback() + + ch := make(chan ipv4.PingReply, 1) + p := ipv4.Pinger{ + Stack: c.s, + NICID: 1, + Addr: stackAddr, + Wait: 10 * time.Millisecond, + Count: 1, // one ping only + } + if err := p.Ping(context.Background(), ch); err != nil { + t.Fatalf("icmp.Ping failed: %v", err) + } + + ping := <-ch + if ping.Error != nil { + t.Errorf("bad ping response: %v", ping.Error) + } +} + +func TestEchoSequence(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + c.loopback() + + const numPings = 3 + ch := make(chan ipv4.PingReply, numPings) + p := ipv4.Pinger{ + Stack: c.s, + NICID: 1, + Addr: stackAddr, + Wait: 10 * time.Millisecond, + Count: numPings, + } + if err := p.Ping(context.Background(), ch); err != nil { + t.Fatalf("icmp.Ping failed: %v", err) + } + + for i := uint16(0); i < numPings; i++ { + ping := <-ch + if ping.Error != nil { + t.Errorf("i=%d bad ping response: %v", i, ping.Error) + } + if ping.SeqNumber != i { + t.Errorf("SeqNumber=%d, want %d", ping.SeqNumber, i) + } + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go new file mode 100644 index 000000000..4cc2a2fd4 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -0,0 +1,233 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ipv4 contains the implementation of the ipv4 network protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the +// network protocols when calling stack.New(). Then endpoints can be created +// by passing ipv4.ProtocolNumber as the network protocol number when calling +// Stack.NewEndpoint(). +package ipv4 + +import ( + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const ( + // ProtocolName is the string representation of the ipv4 protocol name. + ProtocolName = "ipv4" + + // ProtocolNumber is the ipv4 protocol number. + ProtocolNumber = header.IPv4ProtocolNumber + + // maxTotalSize is maximum size that can be encoded in the 16-bit + // TotalLength field of the ipv4 header. + maxTotalSize = 0xffff + + // buckets is the number of identifier buckets. + buckets = 2048 +) + +type address [header.IPv4AddressSize]byte + +type endpoint struct { + nicid tcpip.NICID + id stack.NetworkEndpointID + address address + linkEP stack.LinkEndpoint + dispatcher stack.TransportDispatcher + echoRequests chan echoRequest + fragmentation *fragmentation.Fragmentation +} + +func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint { + e := &endpoint{ + nicid: nicid, + linkEP: linkEP, + dispatcher: dispatcher, + echoRequests: make(chan echoRequest, 10), + fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + } + copy(e.address[:], addr) + e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])} + + go e.echoReplier() + + return e +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +// ID returns the ipv4 endpoint ID. +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// MaxHeaderLength returns the maximum length needed by ipv4 headers (and +// underlying protocols). +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error { + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + length := uint16(hdr.UsedLength() + len(payload)) + id := uint32(0) + if length > header.IPv4MaximumHeaderSize+8 { + // Packets of 68 bytes or less are required by RFC 791 to not be + // fragmented, so we only assign ids to larger packets. + id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1) + } + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: length, + ID: uint16(id), + TTL: 65, + Protocol: uint8(protocol), + SrcAddr: tcpip.Address(e.address[:]), + DstAddr: r.RemoteAddress, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) +} + +// HandlePacket is called by the link layer when new ipv4 packets arrive for +// this endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { + h := header.IPv4(vv.First()) + if !h.IsValid(vv.Size()) { + return + } + + hlen := int(h.HeaderLength()) + tlen := int(h.TotalLength()) + vv.TrimFront(hlen) + vv.CapLength(tlen - hlen) + + more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 + if more || h.FragmentOffset() != 0 { + // The packet is a fragment, let's try to reassemble it. + last := h.FragmentOffset() + uint16(vv.Size()) - 1 + tt, ready := e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + if !ready { + return + } + vv = &tt + } + p := h.TransportProtocol() + if p == header.ICMPv4ProtocolNumber { + e.handleICMP(r, vv) + return + } + e.dispatcher.DeliverTransportPacket(r, p, vv) +} + +// Close cleans up resources associated with the endpoint. +func (e *endpoint) Close() { + close(e.echoRequests) +} + +type protocol struct{} + +// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported +// only for tests that short-circuit the stack. Regular use of the protocol is +// done via the stack, which gets a protocol descriptor from the init() function +// below. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} +} + +// Number returns the ipv4 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv4 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv4MinimumSize +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv4(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint creates a new ipv4 endpoint. +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + return newEndpoint(nicid, addr, dispatcher, linkEP), nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + if mtu > maxTotalSize { + mtu = maxTotalSize + } + return mtu - header.IPv4MinimumSize +} + +// hashRoute calculates a hash value for the given route. It uses the source & +// destination address, the transport protocol number, and a random initial +// value (generated once on initialization) to generate the hash. +func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { + t := r.LocalAddress + a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = r.RemoteAddress + b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return hash.Hash3Words(a, b, uint32(protocol), hashIV) +} + +var ( + ids []uint32 + hashIV uint32 +) + +func init() { + ids = make([]uint32, buckets) + + // Randomly initialize hashIV and the ids. + r := hash.RandN32(1 + buckets) + for i := range ids { + ids[i] = r[i] + } + hashIV = r[buckets] + + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} |