diff options
Diffstat (limited to 'pkg/tcpip')
62 files changed, 4048 insertions, 931 deletions
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index ea0c5413d..8db70a700 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -84,8 +84,8 @@ type VectorisedView struct { size int } -// NewVectorisedView creates a new vectorised view from an already-allocated slice -// of View and sets its size. +// NewVectorisedView creates a new vectorised view from an already-allocated +// slice of View and sets its size. func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } @@ -170,8 +170,9 @@ func (vv *VectorisedView) CapLength(length int) { } // Clone returns a clone of this VectorisedView. -// If the buffer argument is large enough to contain all the Views of this VectorisedView, -// the method will avoid allocations and use the buffer to store the Views of the clone. +// If the buffer argument is large enough to contain all the Views of this +// VectorisedView, the method will avoid allocations and use the buffer to +// store the Views of the clone. func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } @@ -209,7 +210,8 @@ func (vv *VectorisedView) PullUp(count int) (View, bool) { return newFirst, true } -// Size returns the size in bytes of the entire content stored in the vectorised view. +// Size returns the size in bytes of the entire content stored in the +// vectorised view. func (vv *VectorisedView) Size() int { return vv.size } @@ -222,6 +224,12 @@ func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } + return vv.ToOwnedView() +} + +// ToOwnedView returns a single view containing the content of the vectorised +// view that vv does not own. +func (vv *VectorisedView) ToOwnedView() View { u := make([]byte, 0, vv.size) for _, v := range vv.views { u = append(u, v...) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 19627fa9b..d4d785cca 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -118,18 +118,82 @@ func TTL(ttl uint8) NetworkChecker { v = ip.HopLimit() } if v != ttl { - t.Fatalf("Bad TTL, got %v, want %v", v, ttl) + t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) + } + } +} + +// IPFullLength creates a checker for the full IP packet length. The +// expected size is checked against both the Total Length in the +// header and the number of bytes received. +func IPFullLength(packetLength uint16) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + var v uint16 + var l uint16 + switch ip := h[0].(type) { + case header.IPv4: + v = ip.TotalLength() + l = uint16(len(ip)) + case header.IPv6: + v = ip.PayloadLength() + header.IPv6FixedHeaderSize + l = uint16(len(ip)) + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) + } + if l != packetLength { + t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) + } + if v != packetLength { + t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) + } + } +} + +// IPv4HeaderLength creates a checker that checks the IPv4 Header length. +func IPv4HeaderLength(headerLength int) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + switch ip := h[0].(type) { + case header.IPv4: + if hl := ip.HeaderLength(); hl != uint8(headerLength) { + t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) + } + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) } } } // PayloadLen creates a checker that checks the payload length. -func PayloadLen(plen int) NetworkChecker { +func PayloadLen(payloadLength int) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - if l := len(h[0].Payload()); l != plen { - t.Errorf("Bad payload length, got %v, want %v", l, plen) + if l := len(h[0].Payload()); l != payloadLength { + t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) + } + } +} + +// IPv4Options returns a checker that checks the options in an IPv4 packet. +func IPv4Options(want []byte) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + ip, ok := h[0].(header.IPv4) + if !ok { + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) + } + options := ip.Options() + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(options) == 0 { + return + } + if diff := cmp.Diff(want, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) } } } @@ -139,11 +203,11 @@ func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - // We only do this of IPv4 for now. + // We only do this for IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got %v, want %v", v, offset) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) } } } @@ -154,11 +218,11 @@ func FragmentFlags(flags uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - // We only do this of IPv4 for now. + // We only do this for IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got %v, want %v", v, flags) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) } } } @@ -208,7 +272,7 @@ func TOS(tos uint8, label uint32) NetworkChecker { t.Helper() if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) + t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) } } } @@ -234,7 +298,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { t.Helper() if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } ipv6Frag := header.IPv6Fragment(h[0].Payload()) @@ -261,7 +325,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } // Verify the checksum. @@ -297,7 +361,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } udp := header.UDP(last.Payload()) @@ -316,7 +380,7 @@ func SrcPort(port uint16) TransportChecker { t.Helper() if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got %v, want %v", p, port) + t.Errorf("Bad source port, got = %d, want = %d", p, port) } } } @@ -327,7 +391,7 @@ func DstPort(port uint16) TransportChecker { t.Helper() if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got %v, want %v", p, port) + t.Errorf("Bad destination port, got = %d, want = %d", p, port) } } } @@ -359,7 +423,7 @@ func TCPSeqNum(seq uint32) TransportChecker { } if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got %v, want %v", s, seq) + t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) } } } @@ -375,7 +439,7 @@ func TCPAckNum(seq uint32) TransportChecker { } if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got %v, want %v", s, seq) + t.Errorf("Bad ack number, got = %d, want = %d", s, seq) } } } @@ -492,7 +556,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { case header.TCPOptionMSS: v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) if wantOpts.MSS != v { - t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) + t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) } foundMSS = true i += 4 @@ -502,7 +566,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { } v := int(opts[i+2]) if v != wantOpts.WS { - t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) + t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) } foundWS = true i += 3 @@ -551,7 +615,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { t.Error("TS option specified but the timestamp value is zero") } if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) + t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) } if wantOpts.SACKPermitted && !foundSACKPermitted { t.Errorf("SACKPermitted option not found. Options: %x", opts) @@ -589,7 +653,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) } if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) + t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = binary.BigEndian.Uint32(opts[i+6:]) @@ -609,19 +673,19 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } if wantTS != foundTS { - t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) + t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) } if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) + t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) } if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) + t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) } } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not -// contain any SACK blocks in the TCP options. +// TCPNoSACKBlockChecker creates a checker that verifies that the segment does +// not contain any SACK blocks in the TCP options. func TCPNoSACKBlockChecker() TransportChecker { return TCPSACKBlockChecker(nil) } @@ -679,7 +743,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) + t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) } } } @@ -695,8 +759,8 @@ func Payload(want []byte) TransportChecker { } } -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and -// potentially additional ICMPv4 header fields. +// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 +// and potentially additional ICMPv4 header fields. func ICMPv4(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -724,10 +788,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } @@ -739,10 +803,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. +func ICMPv4Ident(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Ident(); got != want { + t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. +func ICMPv4Seq(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Sequence(); got != want { + t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. +// This assumes that the payload exactly makes up the rest of the slice. +func ICMPv4Checksum() TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + heldChecksum := icmpv4.Checksum() + icmpv4.SetChecksum(0) + newChecksum := ^header.Checksum(icmpv4, 0) + icmpv4.SetChecksum(heldChecksum) + if heldChecksum != newChecksum { + t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) + } + } +} + +// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. +func ICMPv4Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + payload := icmpv4.Payload() + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } } } @@ -782,10 +912,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } @@ -797,10 +927,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) } } } diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index eaface8cb..95ade0e5c 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -117,25 +117,31 @@ func (b Ethernet) Encode(e *EthernetFields) { copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) } -// IsValidUnicastEthernetAddress returns true if addr is a valid unicast +// IsMulticastEthernetAddress returns true if the address is a multicast +// ethernet address. +func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool { + if len(addr) != EthernetAddressSize { + return false + } + + return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 +} + +// IsValidUnicastEthernetAddress returns true if the address is a unicast // ethernet address. func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { - // Must be of the right length. if len(addr) != EthernetAddressSize { return false } - // Must not be unspecified. if addr == unspecifiedEthernetAddress { return false } - // Must not be a multicast. if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 { return false } - // addr is a valid unicast ethernet address. return true } diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 14413f2ce..3bc8b2b21 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -67,6 +67,53 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) { } } +func TestIsMulticastEthernetAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.LinkAddress + expected bool + }{ + { + "Nil", + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty", + tcpip.LinkAddress(""), + false, + }, + { + "InvalidLength", + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Unspecified", + unspecifiedEthernetAddress, + false, + }, + { + "Multicast", + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + true, + }, + { + "Unicast", + tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := IsMulticastEthernetAddress(test.addr); got != test.expected { + t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected) + } + }) + } +} + func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { tests := []struct { name string diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c00bcadfb..504408878 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -126,15 +126,6 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } -// SetPointer sets the pointer field in a Parameter error packet. -// This is the first byte of the type specific data field. -func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } - -// SetTypeSpecific sets the full 32 bit type specific data field. -func (b ICMPv4) SetTypeSpecific(val uint32) { - binary.BigEndian.PutUint32(b[icmpv4PointerOffset:], val) -} - // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 4eb5abd79..6be31beeb 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -156,9 +156,14 @@ const ( // ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4. const ( + // ICMPv6ErroneousHeader indicates an erroneous header field was encountered. ICMPv6ErroneousHeader ICMPv6Code = 0 - ICMPv6UnknownHeader ICMPv6Code = 1 - ICMPv6UnknownOption ICMPv6Code = 2 + + // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered. + ICMPv6UnknownHeader ICMPv6Code = 1 + + // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered. + ICMPv6UnknownOption ICMPv6Code = 2 ) // ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use @@ -177,7 +182,12 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } -// SetTypeSpecific sets the full 32 bit type specific data field. +// TypeSpecific returns the type specific data field. +func (b ICMPv6) TypeSpecific() uint32 { + return binary.BigEndian.Uint32(b[icmpv6PointerOffset:]) +} + +// SetTypeSpecific sets the type specific data field. func (b ICMPv6) SetTypeSpecific(val uint32) { binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index b07d9991d..4c6e4be64 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,10 +16,29 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" ) +// RFC 971 defines the fields of the IPv4 header on page 11 using the following +// diagram: ("Figure 4") +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Identification |Flags| Fragment Offset | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Time to Live | Protocol | Header Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Source Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Destination Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options | Padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// const ( versIHL = 0 tos = 1 @@ -33,6 +52,7 @@ const ( checksum = 10 srcAddr = 12 dstAddr = 16 + options = 20 ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the @@ -76,7 +96,8 @@ type IPv4Fields struct { // IPv4 represents an ipv4 header stored in a byte array. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. -// Always call IsValid() to validate an instance of IPv4 before using other methods. +// Always call IsValid() to validate an instance of IPv4 before using other +// methods. type IPv4 []byte const ( @@ -151,13 +172,44 @@ func IPVersion(b []byte) int { if len(b) < versIHL+1 { return -1 } - return int(b[versIHL] >> 4) + return int(b[versIHL] >> ipVersionShift) } +// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits +// of the first byte, and is counted in multiples of 4 bytes. +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// (...) +// Version: 4 bits +// The Version field indicates the format of the internet header. This +// document describes version 4. +// +// IHL: 4 bits +// Internet Header Length is the length of the internet header in 32 +// bit words, and thus points to the beginning of the data. Note that +// the minimum value for a correct header is 5. +// +const ( + ipVersionShift = 4 + ipIHLMask = 0x0f + IPv4IHLStride = 4 +) + // HeaderLength returns the value of the "header length" field of the ipv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { - return (b[versIHL] & 0xf) * 4 + return (b[versIHL] & ipIHLMask) * IPv4IHLStride +} + +// SetHeaderLength sets the value of the "Internet Header Length" field. +func (b IPv4) SetHeaderLength(hdrLen uint8) { + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize)) + } + b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } // ID returns the value of the identifier field of the ipv4 header. @@ -211,6 +263,12 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// Options returns a a buffer holding the options. +func (b IPv4) Options() []byte { + hdrLen := b.HeaderLength() + return b[options:hdrLen:hdrLen] +} + // TransportProtocol implements Network.TransportProtocol. func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber { return tcpip.TransportProtocolNumber(b.Protocol()) @@ -236,6 +294,11 @@ func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } +// SetTTL sets the "Time to Live" field of the IPv4 header. +func (b IPv4) SetTTL(v byte) { + b[ttl] = v +} + // SetTotalLength sets the "total length" field of the ipv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) @@ -276,7 +339,7 @@ func (b IPv4) CalculateChecksum() uint16 { // Encode encodes all the fields of the ipv4 header. func (b IPv4) Encode(i *IPv4Fields) { - b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf) + b.SetHeaderLength(i.IHL) b[tos] = i.TOS b.SetTotalLength(i.TotalLength) binary.BigEndian.PutUint16(b[id:], i.ID) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 0761a1807..c5d8a3456 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -34,6 +34,9 @@ const ( hopLimit = 7 v6SrcAddr = 8 v6DstAddr = v6SrcAddr + IPv6AddressSize + + // IPv6FixedHeaderSize is the size of the fixed header. + IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -69,7 +72,7 @@ type IPv6 []byte const ( // IPv6MinimumSize is the minimum size of a valid IPv6 packet. - IPv6MinimumSize = 40 + IPv6MinimumSize = IPv6FixedHeaderSize // IPv6AddressSize is the size, in bytes, of an IPv6 address. IPv6AddressSize = 16 @@ -306,14 +309,21 @@ func IsV6UnicastAddress(addr tcpip.Address) bool { return addr[0] != 0xff } +const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" + // SolicitedNodeAddr computes the solicited-node multicast address. This is // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 // address. func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { - const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// IsSolicitedNodeAddr determines whether the address is a solicited-node +// multicast address. +func IsSolicitedNodeAddr(addr tcpip.Address) bool { + return solicitedNodeMulticastPrefix == addr[:len(addr)-3] +} + // EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 // from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. // diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 3499d8399..583c2c5d3 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator { // obtained before modification is no longer used. type IPv6OptionsExtHdrOptionsIterator struct { reader bytes.Reader + + // optionOffset is the number of bytes from the first byte of the + // options field to the beginning of the current option. + optionOffset uint32 + + // nextOptionOffset is the offset of the next option. + nextOptionOffset uint32 +} + +// OptionOffset returns the number of bytes parsed while processing the +// option field of the current Extension Header. +func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 { + return i.optionOffset } // IPv6OptionUnknownAction is the action that must be taken if the processing @@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {} // the options data, or an error occured. func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) { for { + i.optionOffset = i.nextOptionOffset temp, err := i.reader.ReadByte() if err != nil { // If we can't read the first byte of a new option, then we know the @@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // know the option does not have Length and Data fields. End processing of // the Pad1 option and continue processing the buffer as a new option. if id == ipv6Pad1ExtHdrOptionIdentifier { + i.nextOptionOffset = i.optionOffset + 1 continue } @@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF) } - // Special-case the variable length padding option to avoid a copy. - if id == ipv6PadNExtHdrOptionIdentifier { - // Do we have enough bytes in the reader for the PadN option? - if n := i.reader.Len(); n < int(length) { - // Reset the reader to effectively consume the remaining buffer. - i.reader.Reset(nil) - - // We return the same error as if we failed to read a non-padding option - // so consumers of this iterator don't need to differentiate between - // padding and non-padding options. - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) - } + // Do we have enough bytes in the reader for the next option? + if n := i.reader.Len(); n < int(length) { + // Reset the reader to effectively consume the remaining buffer. + i.reader.Reset(nil) + + // We return the same error as if we failed to read a non-padding option + // so consumers of this iterator don't need to differentiate between + // padding and non-padding options. + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) + } + + i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */ + switch id { + case ipv6PadNExtHdrOptionIdentifier: + // Special-case the variable length padding option to avoid a copy. if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil { panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } - - // End processing of the PadN option and continue processing the buffer as - // a new option. continue - } - - bytes := make([]byte, length) - if n, err := io.ReadFull(&i.reader, bytes); err != nil { - // io.ReadFull may return io.EOF if i.reader has been exhausted. We use - // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the - // Length field found in the option. - if err == io.EOF { - err = io.ErrUnexpectedEOF + default: + bytes := make([]byte, length) + if n, err := io.ReadFull(&i.reader, bytes); err != nil { + // io.ReadFull may return io.EOF if i.reader has been exhausted. We use + // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the + // Length field found in the option. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) } - - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) + return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } - - return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } } @@ -382,6 +396,29 @@ type IPv6PayloadIterator struct { // Indicates to the iterator that it should return the remaining payload as a // raw payload on the next call to Next. forceRaw bool + + // headerOffset is the offset of the beginning of the current extension + // header starting from the beginning of the fixed header. + headerOffset uint32 + + // parseOffset is the byte offset into the current extension header of the + // field we are currently examining. It can be added to the header offset + // if the absolute offset within the packet is required. + parseOffset uint32 + + // nextOffset is the offset of the next header. + nextOffset uint32 +} + +// HeaderOffset returns the offset to the start of the extension +// header most recently processed. +func (i IPv6PayloadIterator) HeaderOffset() uint32 { + return i.headerOffset +} + +// ParseOffset returns the number of bytes successfully parsed. +func (i IPv6PayloadIterator) ParseOffset() uint32 { + return i.headerOffset + i.parseOffset } // MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing @@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa nextHdrIdentifier: nextHdrIdentifier, payload: payload.Clone(nil), // We need a buffer of size 1 for calls to bufio.Reader.ReadByte. - reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + nextOffset: IPv6FixedHeaderSize, } } @@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { // Next is unable to return anything because the iterator has reached the end of // the payload, or an error occured. func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { + i.headerOffset = i.nextOffset + i.parseOffset = 0 // We could be forced to return i as a raw header when the previous header was // a fragment extension header as the data following the fragment extension // header may not be complete. @@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { return IPv6RoutingExtHdr(bytes), false, nil case IPv6FragmentExtHdrIdentifier: var data [6]byte - // We ignore the returned bytes becauase we know the fragment extension + // We ignore the returned bytes because we know the fragment extension // header specific data will fit in data. nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) if err != nil { @@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP if err != nil { return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) } + i.parseOffset++ var length uint8 length, err = i.reader.ReadByte() i.payload.TrimFront(1) + if err != nil { if fragmentHdr { return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) @@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP length = 0 } + // Make parseOffset point to the first byte of the Extension Header + // specific data. + i.parseOffset++ + + // length is in 8 byte chunks but doesn't include the first one. + // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement + // in section 4.8 for new extension headers at the top of page 24. + // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet + // units, not including the first 8 octets. + i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit) + bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded if bytes == nil { bytes = make([]byte, bytesLen) diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go index b5540bf66..17a49d4fa 100644 --- a/pkg/tcpip/header/ipversion_test.go +++ b/pkg/tcpip/header/ipversion_test.go @@ -22,7 +22,7 @@ import ( func TestIPv4(t *testing.T) { b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) + b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize}) const want = header.IPv4Version if v := header.IPVersion(b); v != want { diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD new file mode 100644 index 000000000..9f31c1ffc --- /dev/null +++ b/pkg/tcpip/link/pipe/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "pipe", + srcs = ["pipe.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go new file mode 100644 index 000000000..76f563811 --- /dev/null +++ b/pkg/tcpip/link/pipe/pipe.go @@ -0,0 +1,124 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pipe provides the implementation of pipe-like data-link layer +// endpoints. Such endpoints allow packets to be sent between two interfaces. +package pipe + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.LinkEndpoint = (*Endpoint)(nil) + +// New returns both ends of a new pipe. +func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) { + ep1 := &Endpoint{ + linkAddr: linkAddr1, + capabilities: capabilities, + } + ep2 := &Endpoint{ + linkAddr: linkAddr2, + linked: ep1, + capabilities: capabilities, + } + ep1.linked = ep2 + return ep1, ep2 +} + +// Endpoint is one end of a pipe. +type Endpoint struct { + capabilities stack.LinkEndpointCapabilities + linkAddr tcpip.LinkAddress + dispatcher stack.NetworkDispatcher + linked *Endpoint + onWritePacket func(*stack.PacketBuffer) +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if !e.linked.IsAttached() { + return nil + } + + // The pipe endpoint will accept all multicast/broadcast link traffic and only + // unicast traffic destined to itself. + if len(e.linked.linkAddr) != 0 && + r.RemoteLinkAddress != e.linked.linkAddr && + r.RemoteLinkAddress != header.EthernetBroadcastAddress && + !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) { + return nil + } + + e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + + return nil +} + +// WritePackets implements stack.LinkEndpoint. +func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint. +func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + panic("not implemented") +} + +// Attach implements stack.LinkEndpoint. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint. +func (e *Endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// Wait implements stack.LinkEndpoint. +func (*Endpoint) Wait() {} + +// MTU implements stack.LinkEndpoint. +func (*Endpoint) MTU() uint32 { + return header.IPv6MinimumMTU +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.capabilities +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (*Endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +// ARPHardwareType implements stack.LinkEndpoint. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} + +// AddHeader implements stack.LinkEndpoint. +func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index b6ddbe81e..f94491026 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -76,13 +76,29 @@ func (d *Device) Release(ctx context.Context) { } } +// NICID returns the NIC ID of the device. +// +// Must only be called after the device has been attached to an endpoint. +func (d *Device) NICID() tcpip.NICID { + d.mu.RLock() + defer d.mu.RUnlock() + + if d.endpoint == nil { + panic("called NICID on a device that has not been attached") + } + + return d.endpoint.nicID +} + // SetIff services TUNSETIFF ioctl(2) request. -func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { +// +// Returns true if a new NIC was created; false if an existing one was attached. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) { d.mu.Lock() defer d.mu.Unlock() if d.endpoint != nil { - return syserror.EINVAL + return false, syserror.EINVAL } // Input validations. @@ -90,7 +106,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { isTap := flags&linux.IFF_TAP != 0 supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { - return syserror.EINVAL + return false, syserror.EINVAL } prefix := "tun" @@ -103,32 +119,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { linkCaps |= stack.CapabilityResolutionRequired } - endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) + endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return syserror.EINVAL + return false, syserror.EINVAL } d.endpoint = endpoint d.notifyHandle = d.endpoint.AddNotify(d) d.flags = flags - return nil + return created, nil } -func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) { for { // 1. Try to attach to an existing NIC. if name != "" { - if nic, found := s.GetNICByName(name); found { - endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if linkEP := s.GetLinkEndpointByName(name); linkEP != nil { + endpoint, ok := linkEP.(*tunEndpoint) if !ok { // Not a NIC created by tun device. - return nil, syserror.EOPNOTSUPP + return nil, false, syserror.EOPNOTSUPP } if !endpoint.TryIncRef() { // Race detected: NIC got deleted in between. continue } - return endpoint, nil + return endpoint, false, nil } } @@ -151,12 +167,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE }) switch err { case nil: - return endpoint, nil + return endpoint, true, nil case tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: - return nil, syserror.EINVAL + return nil, false, syserror.EINVAL } } } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index b47a7be51..7df77c66e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -49,7 +49,6 @@ type endpoint struct { enabled uint32 nic stack.NetworkInterface - linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler } @@ -92,12 +91,12 @@ func (e *endpoint) DefaultTTL() uint8 { } func (e *endpoint) MTU() uint32 { - lmtu := e.linkEP.MTU() + lmtu := e.nic.MTU() return lmtu - uint32(e.MaxHeaderLength()) } func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.ARPSize + return e.nic.MaxHeaderLength() + header.ARPSize } func (e *endpoint) Close() { @@ -154,17 +153,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + origSender := h.HardwareAddressSender() + r.RemoteLinkAddress = tcpip.LinkAddress(origSender) + respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) - packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.HardwareAddressTarget(), origSender) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -207,7 +214,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L e := &endpoint{ protocol: p, nic: nic, - linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, } @@ -223,6 +229,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ + NetProto: ProtocolNumber, RemoteLinkAddress: remoteLinkAddr, } if len(r.RemoteLinkAddress) == 0 { diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index e247f06a4..47fb63290 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -29,6 +29,8 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", ], ) @@ -44,5 +46,7 @@ go_test( deps = [ "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", + "//pkg/tcpip/network/testutil", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index e1909fab0..ed502a473 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -13,7 +13,7 @@ // limitations under the License. // Package fragmentation contains the implementation of IP fragmentation. -// It is based on RFC 791 and RFC 815. +// It is based on RFC 791, RFC 815 and RFC 8200. package fragmentation import ( @@ -25,12 +25,10 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( - // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. - DefaultReassembleTimeout = 30 * time.Second - // HighFragThreshold is the threshold at which we start trimming old // fragmented packets. Linux uses a default value of 4 MB. See // net.ipv4.ipfrag_high_thresh for more information. @@ -243,3 +241,78 @@ func (f *Fragmentation) releaseReassemblersLocked() { f.release(r) } } + +// PacketFragmenter is the book-keeping struct for packet fragmentation. +type PacketFragmenter struct { + transportHeader buffer.View + data buffer.VectorisedView + reserve int + innerMTU int + fragmentCount int + currentFragment int + fragmentOffset int +} + +// MakePacketFragmenter prepares the struct needed for packet fragmentation. +// +// pkt is the packet to be fragmented. +// +// innerMTU is the maximum number of bytes of fragmentable data a fragment can +// have. +// +// reserve is the number of bytes that should be reserved for the headers in +// each generated fragment. +func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter { + // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be + // repeated in each fragment. However we do not currently support any header + // of that kind yet, so the following computation is valid for both IPv4 and + // IPv6. + // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are + // supported for outbound packets, the fragmentable data should not include + // these headers. + var fragmentableData buffer.VectorisedView + fragmentableData.AppendView(pkt.TransportHeader().View()) + fragmentableData.Append(pkt.Data) + fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU + + return PacketFragmenter{ + data: fragmentableData, + reserve: reserve, + innerMTU: innerMTU, + fragmentCount: fragmentCount, + } +} + +// BuildNextFragment returns a packet with the payload of the next fragment, +// along with the fragment's offset, the number of bytes copied and a boolean +// indicating if there are more fragments left or not. If this function is +// called again after it indicated that no more fragments were left, it will +// panic. +// +// Note that the returned packet will not have its network and link headers +// populated, but space for them will be reserved. The transport header will be +// stored in the packet's data. +func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) { + if pf.currentFragment >= pf.fragmentCount { + panic("BuildNextFragment should not be called again after the last fragment was returned") + } + + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: pf.reserve, + }) + + // Copy data for the fragment. + copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU) + + offset := pf.fragmentOffset + pf.fragmentOffset += copied + pf.currentFragment++ + more := pf.currentFragment != pf.fragmentCount + + return fragPkt, offset, copied, more +} + +// RemainingFragmentCount returns the number of fragments left to be built. +func (pf *PacketFragmenter) RemainingFragmentCount() int { + return pf.fragmentCount - pf.currentFragment +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 189b223c5..d3c7d7f92 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -20,10 +20,16 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" ) +// reassembleTimeout is dummy timeout used for testing, where the clock never +// advances. +const reassembleTimeout = 1 + // vv is a helper to build VectorisedView from different strings. func vv(size int, pieces ...string) buffer.VectorisedView { views := make([]buffer.View, len(pieces)) @@ -96,7 +102,7 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) firstFragmentProto := c.in[0].proto for i, in := range c.in { vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) @@ -251,7 +257,7 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) // Send first fragment with id = 1. @@ -275,7 +281,7 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) // Send the same packet again. @@ -370,7 +376,7 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) @@ -381,3 +387,113 @@ func TestErrors(t *testing.T) { }) } } + +type fragmentInfo struct { + remaining int + copied int + offset int + more bool +} + +func TestPacketFragmenter(t *testing.T) { + const ( + reserve = 60 + proto = 0 + ) + + tests := []struct { + name string + innerMTU int + transportHeaderLen int + payloadSize int + wantFragments []fragmentInfo + }{ + { + name: "Packet exactly fits in MTU", + innerMTU: 1280, + transportHeaderLen: 0, + payloadSize: 1280, + wantFragments: []fragmentInfo{ + {remaining: 0, copied: 1280, offset: 0, more: false}, + }, + }, + { + name: "Packet exactly does not fit in MTU", + innerMTU: 1000, + transportHeaderLen: 0, + payloadSize: 1001, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 1000, offset: 0, more: true}, + {remaining: 0, copied: 1, offset: 1000, more: false}, + }, + }, + { + name: "Packet has a transport header", + innerMTU: 560, + transportHeaderLen: 40, + payloadSize: 560, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 560, offset: 0, more: true}, + {remaining: 0, copied: 40, offset: 560, more: false}, + }, + }, + { + name: "Packet has a huge transport header", + innerMTU: 500, + transportHeaderLen: 1300, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {remaining: 3, copied: 500, offset: 0, more: true}, + {remaining: 2, copied: 500, offset: 500, more: true}, + {remaining: 1, copied: 500, offset: 1000, more: true}, + {remaining: 0, copied: 300, offset: 1500, more: false}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) + var originalPayload buffer.VectorisedView + originalPayload.AppendView(pkt.TransportHeader().View()) + originalPayload.Append(pkt.Data) + var reassembledPayload buffer.VectorisedView + pf := MakePacketFragmenter(pkt, test.innerMTU, reserve) + for i := 0; ; i++ { + fragPkt, offset, copied, more := pf.BuildNextFragment() + wantFragment := test.wantFragments[i] + if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { + t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) + } + if copied != wantFragment.copied { + t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) + } + if offset != wantFragment.offset { + t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) + } + if more != wantFragment.more { + t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) + } + if got := fragPkt.Size(); got > test.innerMTU { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU) + } + if got := fragPkt.AvailableHeaderBytes(); got != reserve { + t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) + } + if got := fragPkt.TransportHeader().View().Size(); got != 0 { + t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) + } + reassembledPayload.Append(fragPkt.Data) + if !more { + if i != len(test.wantFragments)-1 { + t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) + } + break + } + } + if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { + t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6861cfdaf..d436873b6 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -270,7 +270,7 @@ func buildDummyStack(t *testing.T) *stack.Stack { var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - tester testObject + testObject mu struct { sync.RWMutex @@ -302,10 +302,6 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } -func (t *testInterface) LinkEndpoint() stack.LinkEndpoint { - return &t.tester -} - func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -517,7 +513,7 @@ func TestIPv4Send(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, @@ -538,10 +534,10 @@ func TestIPv4Send(t *testing.T) { }) // Issue the write. - nic.tester.protocol = 123 - nic.tester.srcAddr = localIPv4Addr - nic.tester.dstAddr = remoteIPv4Addr - nic.tester.contents = payload + nic.testObject.protocol = 123 + nic.testObject.srcAddr = localIPv4Addr + nic.testObject.dstAddr = remoteIPv4Addr + nic.testObject.contents = payload r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -560,12 +556,12 @@ func TestIPv4Receive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -590,10 +586,10 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = view[header.IPv4MinimumSize:totalLen] + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -606,8 +602,8 @@ func TestIPv4Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -640,11 +636,11 @@ func TestIPv4ReceiveControl(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -691,16 +687,16 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = view[dataOffset:] - nic.tester.typ = c.expectedTyp - nic.tester.extra = c.expectedExtra + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[dataOffset:] + nic.testObject.typ = c.expectedTyp + nic.testObject.extra = c.expectedExtra ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) - if want := c.expectedCount; nic.tester.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) + if want := c.expectedCount; nic.testObject.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } }) } @@ -710,12 +706,12 @@ func TestIPv4FragmentationReceive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -758,10 +754,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -776,8 +772,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls) + if nic.testObject.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } // Send second segment. @@ -788,8 +784,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -797,7 +793,7 @@ func TestIPv6Send(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } @@ -821,10 +817,10 @@ func TestIPv6Send(t *testing.T) { }) // Issue the write. - nic.tester.protocol = 123 - nic.tester.srcAddr = localIPv6Addr - nic.tester.dstAddr = remoteIPv6Addr - nic.tester.contents = payload + nic.testObject.protocol = 123 + nic.testObject.srcAddr = localIPv6Addr + nic.testObject.dstAddr = remoteIPv6Addr + nic.testObject.contents = payload r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { @@ -843,11 +839,11 @@ func TestIPv6Receive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -871,10 +867,10 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv6Addr - nic.tester.dstAddr = localIPv6Addr - nic.tester.contents = view[header.IPv6MinimumSize:totalLen] + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { @@ -888,8 +884,8 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -931,11 +927,11 @@ func TestIPv6ReceiveControl(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -994,19 +990,19 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv6Addr - nic.tester.dstAddr = localIPv6Addr - nic.tester.contents = view[dataOffset:] - nic.tester.typ = c.expectedTyp - nic.tester.extra = c.expectedExtra + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[dataOffset:] + nic.testObject.typ = c.expectedTyp + nic.testObject.extra = c.expectedExtra // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) - if want := c.expectedCount; nic.tester.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) + if want := c.expectedCount; nic.testObject.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } }) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 0a7e98ed1..7fc12e229 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -28,12 +28,15 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 5c4f715d7..3407755ed 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,8 @@ package ipv4 import ( + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -77,31 +79,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Echo.Increment() // Only send a reply if the checksum is valid. - wantChecksum := h.Checksum() - // Reset the checksum field to 0 to can calculate the proper - // checksum. We'll have to reset this before we hand the packet - // off. + headerChecksum := h.Checksum() h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - if gotChecksum != wantChecksum { - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) + calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(headerChecksum) + if calculatedChecksum != headerChecksum { + // It's possible that a raw socket still expects to receive this. e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } - // Make a copy of data before pkt gets sent to raw socket. - // DeliverTransportPacket will take ownership of pkt. - replyData := pkt.Data.Clone(nil) - replyData.TrimFront(header.ICMPv4MinimumSize) + // DeliverTransportPacket will take ownership of pkt so don't use it beyond + // this point. Make a deep copy of the data before pkt gets sent as we will + // be modifying fields. + // + // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no + // waiting endpoints. Consider moving responsibility for doing the copy to + // DeliverTransportPacket so that is is only done when needed. + replyData := pkt.Data.ToOwnedView() + replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - remoteLinkAddr := r.RemoteLinkAddress - // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). @@ -117,32 +117,49 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } defer r.Release() - // Use the remote link address from the incoming packet. - r.ResolveWith(remoteLinkAddr) - - // Prepare a reply packet. - icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize) - copy(icmpHdr, h) - icmpHdr.SetType(header.ICMPv4EchoReply) - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0))) - dataVV := buffer.View(icmpHdr).ToVectorisedView() - dataVV.Append(replyData) + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the + // header information, we may have to change this code to handle the + // ICMP header no longer being in the data buffer. + + // Because IP and ICMP are so closely intertwined, we need to handcraft our + // IP header to be able to follow RFC 792. The wording on page 13 is as + // follows: + // IP Fields: + // Addresses + // The address of the source in an echo message will be the + // destination of the echo reply message. To form an echo reply + // message, the source and destination addresses are simply reversed, + // the type code changed to 0, and the checksum recomputed. + // + // This was interpreted by early implementors to mean that all options must + // be copied from the echo request IP header to the echo reply IP header + // and this behaviour is still relied upon by some applications. + // + // Create a copy of the IP header we received, options and all, and change + // The fields we need to alter. + // + // We need to produce the entire packet in the data segment in order to + // use WriteHeaderIncludedPacket(). + replyIPHdr.SetSourceAddress(r.LocalAddress) + replyIPHdr.SetDestinationAddress(r.RemoteAddress) + replyIPHdr.SetTTL(r.DefaultTTL()) + + replyICMPHdr := header.ICMPv4(replyData) + replyICMPHdr.SetType(header.ICMPv4EchoReply) + replyICMPHdr.SetChecksum(0) + replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0)) + + replyVV := buffer.View(replyIPHdr).ToVectorisedView() + replyVV.AppendView(replyData) replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: dataVV, + Data: replyVV, }) - // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header - // information we will have to change this code to handle the ICMP header - // no longer being in the data buffer. replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - // Send out the reply packet. + + // The checksum will be calculated so we don't need to do it here. sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ - Protocol: header.ICMPv4ProtocolNumber, - TTL: r.DefaultTTL(), - TOS: stack.DefaultTOS, - }, replyPkt); err != nil { + if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return } @@ -211,18 +228,18 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonProtoUnreachable is an error where the transport protocol is +// not supported. +type icmpReasonProtoUnreachable struct{} + +func (*icmpReasonProtoUnreachable) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { - sent := r.Stats().ICMP.V4PacketsSent - if !r.Stack().AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -251,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc return nil } + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + sent := p.stack.Stats().ICMP.V4PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() @@ -287,8 +323,6 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // Assume any type we don't know about may be an error type. return nil } - } else if transportHeader.IsEmpty() { - return nil } // Now work out how much of the triggering packet we should return. @@ -303,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // least 8 bytes of the payload must be included. Today linux and other // systems implement the RFC 1812 definition and not the original // requirement. We treat 8 bytes as the minimum but will try send more. - mtu := int(r.MTU()) + mtu := int(route.MTU()) if mtu > header.IPv4MinimumProcessableDatagramSize { mtu = header.IPv4MinimumProcessableDatagramSize } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { @@ -336,19 +370,27 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc ReserveHeaderBytes: headerLen, Data: payload, }) + icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + switch reason.(type) { + case *icmpReasonPortUnreachable: + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + case *icmpReasonProtoUnreachable: + icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter := sent.DstUnreachable icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + counter := sent.DstUnreachable - if err := r.WritePacket( + if err := route.WritePacket( nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, - TTL: r.DefaultTTL(), + TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, }, icmpPkt, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index ad7a767a4..c5ac7b8b5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -18,6 +18,7 @@ package ipv4 import ( "fmt" "sync/atomic" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -30,6 +31,15 @@ import ( ) const ( + // As per RFC 791 section 3.2: + // The current recommendation for the initial timer setting is 15 seconds. + // This may be changed as experience with this protocol accumulates. + // + // Considering that it is an old recommendation, we use the same reassembly + // timeout that linux defines, which is 30 seconds: + // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138 + reassembleTimeout = 30 * time.Second + // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -56,7 +66,6 @@ var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface - linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol @@ -77,7 +86,6 @@ type endpoint struct { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: nic.LinkEndpoint(), dispatcher: dispatcher, protocol: p, } @@ -168,21 +176,13 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) + return calculateMTU(e.nic.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 + return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -190,98 +190,26 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is already present in pkt.NetworkHeader. -// pkt.TransportHeader may be set. mtu includes the IP header and options. This -// does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { - // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.NetworkHeader().View()) - flags := ip.Flags() - - // Update mtu to take into account the header, which will exist in all - // fragments anyway. - innerMTU := mtu - int(ip.HeaderLength()) - - // Round the MTU down to align to 8 bytes. Then calculate the number of - // fragments. Calculate fragment sizes as in RFC791. - innerMTU &^= 7 - n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU - - outerMTU := innerMTU + int(ip.HeaderLength()) - offset := ip.FragmentOffset() - - // Keep the length reserved for link-layer, we need to create fragments with - // the same reserved length. - reservedForLink := pkt.AvailableHeaderBytes() - - // Destroy the packet, pull all payloads out for fragmentation. - transHeader, data := pkt.TransportHeader().View(), pkt.Data - - // Where possible, the first fragment that is sent has the same - // number of bytes reserved for header as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - transFitsFirst := len(transHeader) <= innerMTU - - for i := 0; i < n; i++ { - reserve := reservedForLink + int(ip.HeaderLength()) - if i == 0 && transFitsFirst { - // Reserve for transport header if it's going to be put in the first - // fragment. - reserve += len(transHeader) - } - fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: reserve, - }) - fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - - // Copy data for the fragment. - avail := innerMTU - - if n := len(transHeader); n > 0 { - if n > avail { - n = avail - } - if i == 0 && transFitsFirst { - copy(fragPkt.TransportHeader().Push(n), transHeader) - } else { - fragPkt.Data.AppendView(transHeader[:n:n]) - } - transHeader = transHeader[n:] - avail -= n - } - - if avail > 0 { - n := data.Size() - if n > avail { - n = avail - } - data.ReadToVV(&fragPkt.Data, n) - avail -= n - } - - copied := uint16(innerMTU - avail) - - // Set lengths in header and calculate checksum. - h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) - copy(h, ip) - if i != n-1 { - h.SetTotalLength(uint16(outerMTU)) - h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) - } else { - h.SetTotalLength(uint16(h.HeaderLength()) + copied) - h.SetFlagsFragmentOffset(flags, offset) - } - h.SetChecksum(0) - h.SetChecksum(^h.CalculateChecksum()) - offset += copied +// writePacketFragments fragments pkt and writes the results on the link +// endpoint. The IP header must already present in the original packet. The mtu +// is the maximum size of the packets. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error { + networkHeader := header.IPv4(pkt.NetworkHeader().View()) + fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) + pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) - // Send out the fragment. - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + for { + fragPkt, more := buildNextFragment(&pf, networkHeader) + if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) return err } r.Stats().IP.PacketsSent.Increment() + if !more { + break + } } + return nil } @@ -303,7 +231,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.NetworkProtocolNumber = ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. @@ -329,7 +257,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) @@ -345,10 +273,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) + if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, e.nic.MTU(), pkt) } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.Stats().IP.PacketsSent.Increment() @@ -377,8 +306,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) @@ -392,7 +324,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -401,8 +333,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -461,9 +394,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return nil } + if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } r.Stats().IP.PacketsSent.Increment() - - return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + return nil } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -566,7 +502,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC: 1122 Section 3.2.2.1 + // A host SHOULD generate Destination Unreachable messages with code: + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported + _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -794,14 +736,36 @@ func calculateMTU(mtu uint32) uint32 { return mtu - header.IPv4MinimumSize } +// calculateFragmentInnerMTU calculates the maximum number of bytes of +// fragmentable data a fragment can have, based on the link layer mtu and pkt's +// network header size. +func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { + if mtu > MaxTotalSize { + mtu = MaxTotalSize + } + mtu -= uint32(pkt.NetworkHeader().View().Size()) + // Round the MTU down to align to 8 bytes. + mtu &^= 7 + return mtu +} + +// addressToUint32 translates an IPv4 address into its little endian uint32 +// representation. +// +// This function does the same thing as binary.LittleEndian.Uint32 but operates +// on a tcpip.Address (a string) without the need to convert it to a byte slice, +// which would cause an allocation. +func addressToUint32(addr tcpip.Address) uint32 { + _ = addr[3] // bounds check hint to compiler + return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 +} + // hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number, and a random initial -// value (generated once on initialization) to generate the hash. +// destination address, the transport protocol number and a 32-bit number to +// generate the hash. func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - t := r.LocalAddress - a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 - t = r.RemoteAddress - b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + a := addressToUint32(r.LocalAddress) + b := addressToUint32(r.RemoteAddress) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } @@ -821,6 +785,29 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + } +} + +func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { + fragPkt, offset, copied, more := pf.BuildNextFragment() + fragPkt.NetworkProtocolNumber = ProtocolNumber + + originalIPHeaderLength := len(originalIPHeader) + nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + + if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { + panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) } + + flags := originalIPHeader.Flags() + if more { + flags |= header.IPv4FlagMoreFragments + } + nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset)) + nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied)) + nextFragIPHeader.SetChecksum(0) + nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum()) + + return fragPkt, more } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 277560e35..9916d783f 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -16,19 +16,24 @@ package ipv4_test import ( "bytes" + "context" "encoding/hex" "math" + "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -92,6 +97,276 @@ func TestExcludeBroadcast(t *testing.T) { }) } +// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and +// checks the response. +func TestIPv4Sanity(t *testing.T) { + const ( + defaultMTU = header.IPv6MinimumMTU + ttl = 255 + nicID = 1 + randomSequence = 123 + randomIdent = 42 + ) + var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + ) + + tests := []struct { + name string + headerLength uint8 // value of 0 means "use correct size" + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + shouldFail bool + expectICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + options []byte + }{ + { + name: "valid", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + }, + // The TTL tests check that we are not rejecting an incoming packet + // with a zero or one TTL, which has been a point of confusion in the + // past as RFC 791 says: "If this field contains the value zero, then the + // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies + // for the case of the destination host, stating as follows. + // + // A host MUST NOT send a datagram with a Time-to-Live (TTL) + // value of zero. + // + // A host MUST NOT discard a datagram just because it was + // received with TTL less than 2. + { + name: "zero TTL", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 0, + shouldFail: false, + }, + { + name: "one TTL", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + shouldFail: false, + }, + { + name: "End options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{0, 0, 0, 0}, + }, + { + name: "NOP options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 1, 1}, + }, + { + name: "NOP and End options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 0, 0}, + }, + { + name: "bad header length", + headerLength: header.IPv4MinimumSize - 1, + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (0)", + maxTotalLength: 0, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip + icmp - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad protocol", + maxTotalLength: defaultMTU, + transportProtocol: 99, + TTL: ttl, + shouldFail: true, + expectICMP: true, + ICMPType: header.ICMPv4DstUnreachable, + ICMPCode: header.ICMPv4ProtoUnreachable, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + } + + // Default routes for IPv4 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // Round up the header size to the next multiple of 4 as RFC 791, page 11 + // says: "Internet Header Length is the length of the internet header + // in 32 bit words..." and on page 23: "The internet header padding is + // used to ensure that the internet header ends on a 32 bit boundary." + ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1) + + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + + // Specify ident/seq to make sure we get the same in the response. + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + if test.maxTotalLength < totalLen { + totalLen = test.maxTotalLength + } + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHeaderLength), + TotalLength: totalLen, + Protocol: test.transportProtocol, + TTL: test.TTL, + SrcAddr: remoteIPv4Addr, + DstAddr: ipv4Addr.Address, + }) + if n := copy(ip.Options(), test.options); n != len(test.options) { + t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options)) + } + // Override the correct value if the test case specified one. + if test.headerLength != 0 { + ip.SetHeaderLength(test.headerLength) + } + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + reply, ok := e.Read() + if !ok { + if test.shouldFail { + if test.expectICMP { + t.Fatal("expected ICMP error response missing") + } + return // Expected silent failure. + } + t.Fatal("expected ICMP echo reply missing") + } + + // Check the route that brought the packet to us. + if reply.Route.LocalAddress != ipv4Addr.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) + } + if reply.Route.RemoteAddress != remoteIPv4Addr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) + } + + // Make sure it's all in one buffer. + vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) + replyIPHeader := header.IPv4(vv.ToView()) + + // At this stage we only know it's an IP header so verify that much. + checker.IPv4(t, replyIPHeader, + checker.SrcAddr(ipv4Addr.Address), + checker.DstAddr(remoteIPv4Addr), + ) + + // All expected responses are ICMP packets. + if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { + t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + } + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) + + // Sanity check the response. + switch replyICMPHeader.Type() { + case header.ICMPv4DstUnreachable: + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + if !test.shouldFail || !test.expectICMP { + t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + return + case header.ICMPv4EchoReply: + checker.IPv4(t, replyIPHeader, + checker.IPv4HeaderLength(ipHeaderLength), + checker.IPv4Options(test.options), + checker.IPFullLength(uint16(requestPkt.Size())), + checker.ICMPv4( + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Seq(randomSequence), + checker.ICMPv4Ident(randomIdent), + checker.ICMPv4Checksum(), + ), + ) + if test.shouldFail { + t.Fatalf("unexpected Echo Reply packet\n") + } + default: + t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + } + }) + } +} + // comparePayloads compared the contents of all the packets against the contents // of the source packet. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { @@ -123,16 +398,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if i == 0 { - got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() - // sourcePacketInfo does not have NetworkHeader added, simulate one. - want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() - // Check that it kept the transport header in packet.TransportHeader if - // it fits in the first fragment. - if want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) - } - } if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } @@ -162,6 +427,8 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } func TestFragmentation(t *testing.T) { + const ttl = 42 + var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { manyPayloadViewsSizes[i] = 7 @@ -175,15 +442,15 @@ func TestFragmentation(t *testing.T) { payloadViewsSizes []int expectedFrags int }{ - {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, + {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, + {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, + {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, + {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, + {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, + {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, + {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, + {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, } for _, ft := range fragTests { @@ -194,11 +461,11 @@ func TestFragmentation(t *testing.T) { source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, - TTL: 42, + TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Errorf("got err = %s, want = nil", err) + t.Fatalf("r.WritePacket(_, _, _) = %s", err) } if got := len(ep.WrittenPackets); got != ft.expectedFrags { @@ -207,6 +474,9 @@ func TestFragmentation(t *testing.T) { if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } compareFragments(t, ep.WrittenPackets, source, ft.mtu) }) } @@ -215,36 +485,70 @@ func TestFragmentation(t *testing.T) { // TestFragmentationErrors checks that errors are returned from write packet // correctly. func TestFragmentationErrors(t *testing.T) { + const ttl = 42 + + expectedError := tcpip.ErrAborted fragTests := []struct { description string mtu uint32 transportHeaderLength int - payloadViewsSizes []int - err *tcpip.Error + payloadSize int allowPackets int + fragmentCount int }{ - {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1}, - {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0}, + { + description: "No frag", + mtu: 2000, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 0, + fragmentCount: 1, + }, + { + description: "Error on first frag", + mtu: 500, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 0, + fragmentCount: 3, + }, + { + description: "Error on second frag", + mtu: 500, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 1, + fragmentCount: 3, + }, + { + description: "Error on first frag MTU smaller than header", + mtu: 500, + transportHeaderLength: 1000, + payloadSize: 500, + allowPackets: 0, + fragmentCount: 4, + }, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets) + ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, - TTL: 42, + TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.err { - t.Errorf("got WritePacket() = %s, want = %s", err, ft.err) + if err != expectedError { + t.Errorf("got WritePacket() = %s, want = %s", err, expectedError) } if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) } + if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) + } }) } } @@ -1005,6 +1309,7 @@ func TestReceiveFragments(t *testing.T) { func TestWriteStats(t *testing.T) { const nPackets = 3 + tests := []struct { name string setup func(*testing.T, *stack.Stack) @@ -1150,12 +1455,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { dst = "\x10\x00\x00\x02" ) if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err) + t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) } { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + mask := tcpip.AddressMask(header.IPv4Broadcast) + subnet, err := tcpip.NewSubnet(dst, mask) if err != nil { - t.Fatalf("NewSubnet(_, _) failed: %v", err) + t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, @@ -1164,7 +1470,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { } rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err) } return rt } @@ -1188,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, lm.limit-- return false, false } + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 8, + }, + } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, + TTL: ipv4.DefaultTTL, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + checker.ICMPv4Code(header.ICMPv4UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a ARP request since link address resolution should be + // performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != arp.ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) + } + if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(p.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) + } + } + + // Send an ARP reply to complete link address resolution. + { + hdr := buffer.View(make([]byte, header.ARPSize)) + packet := header.ARP(hdr) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), host2NICLinkAddr) + copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) + copy(packet.HardwareAddressTarget(), host1NICLinkAddr) + copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) + e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 97adbcbd4..a30437f02 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -18,6 +18,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", + "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", ], ) @@ -41,6 +42,7 @@ go_test( "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 4b4b483cc..7be35c78b 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -286,6 +286,17 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) } + // As per RFC 4861 section 7.1.1: + // A node MUST silently discard any received Neighbor Solicitation + // messages that do not satisfy all of the following validity checks: + // ... + // - If the IP source address is the unspecified address, the IP + // destination address is a solicited-node multicast address. + if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + received.Invalid.Increment() + return + } + // ICMPv6 Neighbor Solicit messages are always sent to // specially crafted IPv6 multicast addresses. As a result, the // route we end up with here has as its LocalAddress such a @@ -429,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - remoteLinkAddr := r.RemoteLinkAddress - // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. localAddr := r.LocalAddress @@ -445,9 +454,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } defer r.Release() - // Use the link address from the source of the original packet. - r.ResolveWith(remoteLinkAddr) - replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data, @@ -635,18 +641,21 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - snaddr := header.SolicitedNodeAddr(addr) - // TODO(b/148672031): Use stack.FindRoute instead of manually creating the // route here. Note, we would need the nicID to do this properly so the right // NIC (associated to linkEP) is used to send the NDP NS message. - r := &stack.Route{ + r := stack.Route{ LocalAddress: localAddr, - RemoteAddress: snaddr, + RemoteAddress: addr, RemoteLinkAddress: remoteLinkAddr, } + + // If a remote address is not already known, then send a multicast + // solicitation since multicast addresses have a static mapping to link + // addresses. if len(r.RemoteLinkAddress) == 0 { - r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) + r.RemoteAddress = header.SolicitedNodeAddr(addr) + r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress) } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -672,7 +681,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -690,6 +699,36 @@ type icmpReason interface { isICMPReason() } +// icmpReasonParameterProblem is an error during processing of extension headers +// or the fixed header defined in RFC 4443 section 3.4. +type icmpReasonParameterProblem struct { + code header.ICMPv6Code + + // respondToMulticast indicates that we are sending a packet that falls under + // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondToMulticast bool + + // pointer is defined in the RFC 4443 setion 3.4 which reads: + // + // Pointer Identifies the octet offset within the invoking packet + // where the error was detected. + // + // The pointer will point beyond the end of the ICMPv6 + // packet if the field in error is beyond what can fit + // in the maximum size of an ICMPv6 error message. + pointer uint32 +} + +func (*icmpReasonParameterProblem) isICMPReason() {} + // icmpReasonPortUnreachable is an error where the transport protocol has no // listener and no alternative means to inform the sender. type icmpReasonPortUnreachable struct{} @@ -698,18 +737,11 @@ func (*icmpReasonPortUnreachable) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { - stats := r.Stats().ICMP - sent := stats.V6PacketsSent - if !r.Stack().AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // - // TODO(b/164522993) There are exceptions to this rule. + // There are exceptions to this rule. // See: point e.3) RFC 4443 section-2.4 // // (e) An ICMPv6 error message MUST NOT be originated as a result of @@ -727,7 +759,32 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // Section 4.2 of [IPv6]) that has the Option Type highest- // order two bits set to 10). // - if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any { + var allowResponseToMulticast bool + if reason, ok := reason.(*icmpReasonParameterProblem); ok { + allowResponseToMulticast = reason.respondToMulticast + } + + if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + return nil + } + + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + stats := p.stack.Stats().ICMP + sent := stats.V6PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() return nil } @@ -757,11 +814,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // packet that caused the error) as possible without making // the error message packet exceed the minimum IPv6 MTU // [IPv6]. - mtu := int(r.MTU()) + mtu := int(route.MTU()) if mtu > header.IPv6MinimumMTU { mtu = header.IPv6MinimumMTU } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize + headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize available := int(mtu) - headerLen if available < header.IPv6MinimumSize { return nil @@ -780,12 +837,30 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data)) - counter := sent.DstUnreachable - err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt) - if err != nil { + var counter *tcpip.StatCounter + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + icmpHdr.SetType(header.ICMPv6ParamProblem) + icmpHdr.SetCode(reason.code) + icmpHdr.SetTypeSpecific(reason.pointer) + counter = sent.ParamProblem + case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + counter = sent.DstUnreachable + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data)) + if err := route.WritePacket( + nil, /* gso */ + stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + newPkt, + ); err != nil { sent.Dropped.Increment() return err } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 31370c1d4..3affcc4e4 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -16,17 +16,21 @@ package ipv6 import ( "context" + "net" "reflect" "strings" "testing" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -39,6 +43,9 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 + + // Extra time to use when waiting for an async event to occur. + defaultAsyncPositiveEventTimeout = 30 * time.Second ) var ( @@ -50,6 +57,10 @@ type stubLinkEndpoint struct { stack.LinkEndpoint } +func (*stubLinkEndpoint) MTU() uint32 { + return defaultMTU +} + func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { // Indicate that resolution for link layer addresses is required to send // packets over this link. This is needed so the NIC knows to allocate a @@ -105,7 +116,9 @@ func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { var _ stack.NetworkInterface = (*testInterface)(nil) -type testInterface struct{} +type testInterface struct { + stack.NetworkLinkEndpoint +} func (*testInterface) ID() tcpip.NICID { return 0 @@ -123,10 +136,6 @@ func (*testInterface) Enabled() bool { return true } -func (*testInterface) LinkEndpoint() stack.LinkEndpoint { - return nil -} - func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -1219,19 +1228,22 @@ func TestLinkAddressRequest(t *testing.T) { mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) tests := []struct { - name string - remoteLinkAddr tcpip.LinkAddress - expectLinkAddr tcpip.LinkAddress + name string + remoteLinkAddr tcpip.LinkAddress + expectedLinkAddr tcpip.LinkAddress + expectedAddr tcpip.Address }{ { - name: "Unicast", - remoteLinkAddr: linkAddr1, - expectLinkAddr: linkAddr1, + name: "Unicast", + remoteLinkAddr: linkAddr1, + expectedLinkAddr: linkAddr1, + expectedAddr: lladdr0, }, { - name: "Multicast", - remoteLinkAddr: "", - expectLinkAddr: mcaddr, + name: "Multicast", + remoteLinkAddr: "", + expectedLinkAddr: mcaddr, + expectedAddr: snaddr, }, } @@ -1254,9 +1266,229 @@ func TestLinkAddressRequest(t *testing.T) { if !ok { t.Fatal("expected to send a link address request") } + if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr) + } + if pkt.Route.RemoteAddress != test.expectedAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr) + } + if pkt.Route.LocalAddress != lladdr1 { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) + } + checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), + checker.SrcAddr(lladdr1), + checker.DstAddr(test.expectedAddr), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), + )) + } +} + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + checker.ICMPv6Code(header.ICMPv6UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a neighbor solicitation since link address resolution should + // be performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}), + )) + } + + // Send a neighbor advertisement to complete link address resolution. + { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) + pkt := header.ICMPv6(hdr.Prepend(naSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), + }) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index aff4e1425..2bd8f4ece 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,9 +16,12 @@ package ipv6 import ( + "encoding/binary" "fmt" + "hash/fnv" "sort" "sync/atomic" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -26,10 +29,20 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( + // As per RFC 8200 section 4.5: + // If insufficient fragments are received to complete reassembly of a packet + // within 60 seconds of the reception of the first-arriving fragment of that + // packet, reassembly of that packet must be abandoned. + // + // Linux also uses 60 seconds for reassembly timeout: + // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456 + reassembleTimeout = 60 * time.Second + // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber @@ -40,6 +53,9 @@ const ( // DefaultTTL is the default hop limit for IPv6 Packets egressed by // Netstack. DefaultTTL = 64 + + // buckets for fragment identifiers + buckets = 2048 ) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) @@ -50,7 +66,6 @@ var _ NDPEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface - linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler dispatcher stack.TransportDispatcher @@ -348,21 +363,13 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) + return calculateMTU(e.nic.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 + return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { @@ -376,7 +383,44 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + pkt.NetworkProtocolNumber = ProtocolNumber +} + +func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { + return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) +} + +// handleFragments fragments pkt and calls the handler function on each +// fragment. It returns the number of fragments handled and the number of +// fragments left to be processed. The IP header must already be present in the +// original packet. The mtu is the maximum size of the packets. The transport +// header protocol number is required to avoid parsing the IPv6 extension +// headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) + if fragMTU < pkt.TransportHeader().View().Size() { + // As per RFC 8200 Section 4.5, the Transport Header is expected to be small + // enough to fit in the first fragment. + return 0, 1, tcpip.ErrMessageTooLong + } + + pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt)) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + var n int + for { + fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id) + if err := handler(fragPkt); err != nil { + return n, pf.RemainingFragmentCount() + 1, err + } + n++ + if !more { + break + } + } + + return n, 0, nil } // WritePacket writes a packet to the given destination address and protocol. @@ -402,7 +446,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) return nil @@ -423,14 +467,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if e.packetMustBeFragmented(pkt, gso) { + sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each + // fragment one by one using WritePacket() (current strategy) or if we + // want to create a PacketBufferList from the fragments and feed it to + // WritePackets(). It'll be faster but cost more memory. + return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + }) + r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } + + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + r.Stats().IP.PacketsSent.Increment() return nil } -// WritePackets implements stack.LinkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") @@ -441,6 +500,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) + if e.packetMustBeFragmented(pb, gso) { + current := pb + _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Modify the packet list in place with the new fragments. + pkts.InsertAfter(current, fragPkt) + current = current.Next() + return nil + }) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + // The fragmented packet can be released. The rest of the packets can be + // processed. + pkts.Remove(pb) + pb = current + } } // iptables filtering. All packets that reach here are locally @@ -451,8 +527,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) @@ -466,7 +545,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -475,8 +554,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -536,7 +616,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } - for firstHeader := true; ; firstHeader = false { + for { + // Keep track of the start of the previous header so we can report the + // special case of a Hop by Hop at a location other than at the start. + previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -550,11 +633,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 - // (unrecognized next header) error in response to an extension header's - // Next Header field with the Hop By Hop extension header identifier. - if !firstHeader { + if previousHeaderStart != 0 { + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + }, pkt) return } @@ -576,13 +659,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) @@ -593,16 +688,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.4, if a node encounters a routing header with // an unrecognized routing type value, with a non-zero Segments Left // value, the node must discard the packet and send an ICMP Parameter - // Problem, Code 0. If the Segments Left is 0, the node must ignore the - // Routing extension header and process the next header in the packet. + // Problem, Code 0 to the packet's Source Address, pointing to the + // unrecognized Routing Type. + // + // If the Segments Left is 0, the node must ignore the Routing extension + // header and process the next header in the packet. // // Note, the stack does not yet handle any type of routing extension // header, so we just make sure Segments Left is zero before processing // the next extension header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for - // unrecognized routing types with a non-zero Segments Left value. if extHdr.SegmentsLeft() != 0 { + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: it.ParseOffset(), + }, pkt) return } @@ -737,13 +836,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) @@ -767,8 +878,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { pkt.TransportProtocolNumber = p e.handleICMP(r, pkt, hasFragmentHeader) } else { - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. + r.Stats().IP.PacketsDelivered.Increment() switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: @@ -777,18 +887,41 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC 8200 section 4. (page 7): + // Extension headers are numbered from IANA IP Protocol Numbers + // [IANA-PN], the same values used for IPv4 and IPv6. When + // processing a sequence of Next Header values in a packet, the + // first one that is not an extension header [IANA-EH] indicates + // that the next item in the packet is the corresponding upper-layer + // header. + // With more related information on page 8: + // If, as a result of processing a header, the destination node is + // required to proceed to the next header but the Next Header value + // in the current header is unrecognized by the node, it should + // discard the packet and send an ICMP Parameter Problem message to + // the source of the packet, with an ICMP Code value of 1 + // ("unrecognized Next Header type encountered") and the ICMP + // Pointer field containing the offset of the unrecognized value + // within the original packet. + // + // Which when taken together indicate that an unknown protocol should + // be treated as an unrecognized next header value. + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } } default: - // If we receive a packet for an extension header we do not yet handle, - // drop the packet for now. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) r.Stats().UnknownProtocolRcvdPackets.Increment() return } @@ -1097,6 +1230,9 @@ type protocol struct { eps map[*endpoint]struct{} } + ids []uint32 + hashIV uint32 + // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful. // @@ -1157,7 +1293,6 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, dispatcher: dispatcher, @@ -1318,10 +1453,15 @@ type Options struct { func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { opts.NDPConfigs.validate() + ids := hash.RandN32(buckets) + hashIV := hash.RandN32(1)[0] + return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + ids: ids, + hashIV: hashIV, ndpDisp: opts.NDPDisp, ndpConfigs: opts.NDPConfigs, @@ -1339,3 +1479,73 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return NewProtocolWithOptions(Options{})(s) } + +// calculateFragmentInnerMTU calculates the maximum number of bytes of +// fragmentable data a fragment can have, based on the link layer mtu and pkt's +// network header size. +func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // MTU because they should only be transmitted once. + mtu -= uint32(pkt.NetworkHeader().View().Size()) + mtu -= header.IPv6FragmentHeaderSize + // Round the MTU down to align to 8 bytes. + mtu &^= 7 + if mtu <= maxPayloadSize { + return mtu + } + return maxPayloadSize +} + +func calculateFragmentReserve(pkt *stack.PacketBuffer) int { + return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize +} + +// hashRoute calculates a hash value for the given route. It uses the source & +// destination address and 32-bit number to generate the hash. +func hashRoute(r *stack.Route, hashIV uint32) uint32 { + // The FNV-1a was chosen because it is a fast hashing algorithm, and + // cryptographic properties are not needed here. + h := fnv.New32a() + if _, err := h.Write([]byte(r.LocalAddress)); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err)) + } + + if _, err := h.Write([]byte(r.RemoteAddress)); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err)) + } + + s := make([]byte, 4) + binary.LittleEndian.PutUint32(s, hashIV) + if _, err := h.Write(s); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err)) + } + + return h.Sum32() +} + +func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) { + fragPkt, offset, copied, more := pf.BuildNextFragment() + fragPkt.NetworkProtocolNumber = ProtocolNumber + + originalIPHeadersLength := len(originalIPHeaders) + fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) + + // Copy the IPv6 header and any extension headers already populated. + if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { + panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) + } + fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) + + fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) + fragmentHeader.Encode(&header.IPv6FragmentFields{ + M: more, + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + Identification: id, + NextHeader: uint8(transportProto), + }) + + return fragPkt, more +} diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 94344057e..e792ca9e2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,17 +15,21 @@ package ipv6 import ( + "encoding/hex" + "fmt" "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -136,6 +140,82 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst } } +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { + // sourcePacket does not have its IP Header populated. Let's copy the one + // from the first fragment. + source := header.IPv6(packets[0].NetworkHeader().View()) + sourceIPHeadersLen := len(source) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) + source = append(source, vv.ToView()...) + + var reassembledPayload buffer.VectorisedView + for i, fragment := range packets { + // Confirm that the packet is valid. + allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views()) + fragmentIPHeaders := header.IPv6(allBytes.ToView()) + if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders)) + } + + fragmentIPHeadersLength := fragment.NetworkHeader().View().Size() + if fragmentIPHeadersLength != sourceIPHeadersLen { + return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen) + } + + if got := len(fragmentIPHeaders); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu) + } + + sourceIPHeader := source[:header.IPv6MinimumSize] + fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize] + + if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize { + return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize) + } + + // We expect the IPv6 Header to be similar across each fragment, besides the + // payload length. + sourceIPHeader.SetPayloadLength(0) + fragmentIPHeader.SetPayloadLength(0) + if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) + } + + if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { + return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) + } + + if len(packets) > 1 { + // If the source packet was big enough that it needed fragmentation, let's + // inspect the fragment header. Because no other extension headers are + // supported, it will always be the last extension header. + fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength]) + + if got := fragmentHeader.More(); got != wantFragments[i].more { + return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more) + } + if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset { + return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset) + } + if got := fragmentHeader.NextHeader(); got != uint8(proto) { + return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto)) + } + } + + // Store the reassembled payload as we parse each fragment. The payload + // includes the Transport header and everything after. + reassembledPayload.AppendView(fragment.TransportHeader().View()) + reassembledPayload.Append(fragment.Data) + } + + result := reassembledPayload.ToView() + if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + + return nil +} + // TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and // UDP packets destined to the IPv6 link-local all-nodes multicast address. func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { @@ -170,8 +250,6 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // packets destined to the IPv6 solicited-node address of an assigned IPv6 // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - const nicID = 1 - tests := []struct { name string protocolFactory stack.TransportProtocolFactory @@ -195,7 +273,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv6EmptySubnet, NIC: nicID, }, @@ -295,17 +373,22 @@ func TestAddIpv6Address(t *testing.T) { } func TestReceiveIPv6ExtHdrs(t *testing.T) { - const nicID = 1 - tests := []struct { name string extHdr func(nextHdr uint8) ([]byte, uint8) shouldAccept bool + // Should we expect an ICMP response and if so, with what contents? + expectICMP bool + ICMPType header.ICMPv6Type + ICMPCode header.ICMPv6Code + pointer uint32 + multicast bool }{ { name: "None", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, shouldAccept: true, + expectICMP: false, }, { name: "hopbyhop with unknown option skippable action", @@ -336,9 +419,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, + }, + { + name: "hopbyhop with unknown option discard and send icmp action (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action", + name: "hopbyhop with unknown option discard and send icmp action (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -348,12 +452,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest", + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -364,39 +474,97 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + multicast: true, shouldAccept: false, + expectICMP: false, }, { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: true, }, { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 1, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6ErroneousHeader, + pointer: header.IPv6FixedHeaderSize + 2, }, { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 0, 0, 0, 0, + }, fragmentExtHdrID + }, shouldAccept: true, }, { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: true, + expectICMP: false, }, { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: false, + expectICMP: false, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{}, + noNextHdrID + }, shouldAccept: false, + expectICMP: false, }, { name: "destination with unknown option skippable action", @@ -412,6 +580,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: true, + expectICMP: false, }, { name: "destination with unknown option discard action", @@ -427,9 +596,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: false, + expectICMP: false, }, { - name: "destination with unknown option discard and send icmp action", + name: "destination with unknown option discard and send icmp action (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -439,12 +609,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action unless multicast dest", + name: "destination with unknown option discard and send icmp action (muilticast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. + }, destinationExtHdrID + }, + multicast: true, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -455,22 +651,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "routing - atomic fragment", + name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + nextHdr, 1, - // Fragment extension header. - nextHdr, 0, 0, 0, 1, 2, 3, 4, - }, routingExtHdrID + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. + }, destinationExtHdrID }, - shouldAccept: true, + shouldAccept: false, + expectICMP: false, + multicast: true, }, { name: "atomic fragment - routing", @@ -504,12 +711,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { return []byte{ // Routing extension header. hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. // Hop By Hop extension header with skippable unknown option. nextHdr, 0, 62, 4, 1, 2, 3, 4, }, routingExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { + name: "routing - hop by hop (with send icmp unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. + + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, routingExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, }, { name: "No next header", @@ -553,6 +790,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", @@ -573,6 +811,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, } @@ -582,7 +821,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -590,6 +829,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + // Add a default route so that a return packet knows where to go. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -631,12 +878,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: ipv6NextHdr, HopLimit: 255, SrcAddr: addr1, - DstAddr: addr2, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -650,6 +901,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = 0", got) } + if !test.expectICMP { + if p, ok := e.Read(); ok { + t.Fatalf("unexpected packet received: %#v", p) + } + return + } + + // ICMP required. + p, ok := e.Read() + if !ok { + t.Fatalf("expected packet wasn't written out") + } + + // Pack the output packet into a single buffer.View as the checkers + // assume that. + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() + if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { + t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) + } + + ipHdr := header.IPv6(pkt) + checker.IPv6(t, ipHdr, checker.ICMPv6( + checker.ICMPv6Type(test.ICMPType), + checker.ICMPv6Code(test.ICMPCode))) + + // We know we are looking at no extension headers in the error ICMP + // packets. + icmpPkt := header.ICMPv6(ipHdr.Payload()) + // We know we sent small packets that won't be truncated when reflected + // back to us. + originalPacket := icmpPkt.Payload() + if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { + t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) + } + if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { + t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) + } return } @@ -683,7 +972,6 @@ type fragmentData struct { func TestReceiveIPv6Fragments(t *testing.T) { const ( - nicID = 1 udpPayload1Length = 256 udpPayload2Length = 128 // Used to test cases where the fragment blocks are not a multiple of @@ -1815,7 +2103,6 @@ func TestWriteStats(t *testing.T) { t.Run(test.name, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) rt := buildRoute(t, ep) - var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1857,12 +2144,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err) + t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) } { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) + mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") + subnet, err := tcpip.NewSubnet(dst, mask) if err != nil { - t.Fatalf("NewSubnet(_, _) failed: %v", err) + t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, @@ -1871,7 +2159,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { } rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err) } return rt } @@ -1922,3 +2210,293 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { } } } + +type fragmentInfo struct { + offset uint16 + more bool + payloadSize uint16 +} + +type fragmentationTestCase struct { + description string + mtu uint32 + gso *stack.GSO + transHdrLen int + extraHdrLen int + payloadSize int + wantFragments []fragmentInfo + expectedFrags int +} + +var fragmentationTests = []fragmentationTestCase{ + { + description: "No Fragmentation", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1000, more: false}, + }, + }, + { + description: "Fragmented", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "No fragmentation with big header", + mtu: 2000, + gso: &stack.GSO{}, + transHdrLen: 100, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1100, more: false}, + }, + }, + { + description: "Fragmented with gso nil", + mtu: 1280, + gso: nil, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1400, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 176, more: false}, + }, + }, + { + description: "Fragmented with big header", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 100, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1200, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 76, more: false}, + }, + }, + { + description: "Fragmented with big header and prependable bytes", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 20, + extraHdrLen: header.IPv6MinimumSize + 66, + payloadSize: 1500, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 296, more: false}, + }, + }, +} + +func TestFragmentation(t *testing.T) { + const ( + ttl = 42 + tos = stack.DefaultTOS + transportProto = tcp.ProtocolNumber + ) + + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + source := pkt.Clone() + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != nil { + t.Fatalf("WritePacket(_, _, _): = %s", err) + } + if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + if len(ep.WrittenPackets) > 0 { + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + } + }) + } +} + +func TestFragmentationWritePackets(t *testing.T) { + const ttl = 42 + tests := []struct { + description string + insertBefore int + insertAfter int + }{ + { + description: "Single packet", + insertBefore: 0, + insertAfter: 0, + }, + { + description: "With packet before", + insertBefore: 1, + insertAfter: 0, + }, + { + description: "With packet after", + insertBefore: 0, + insertAfter: 1, + }, + { + description: "With packet before and after", + insertBefore: 1, + insertAfter: 1, + }, + } + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + var pkts stack.PacketBufferList + for i := 0; i < test.insertBefore; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + source := pkt + pkts.PushBack(pkt.Clone()) + for i := 0; i < test.insertAfter; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + + wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter + n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }) + if n != wantTotalPackets || err != nil { + t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets) + } + if got := len(ep.WrittenPackets); got != wantTotalPackets { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + + if wantTotalPackets == 0 { + return + } + + fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] + if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } + }) + } +} + +// TestFragmentationErrors checks that errors are returned from WritePacket +// correctly. +func TestFragmentationErrors(t *testing.T) { + const ttl = 42 + + tests := []struct { + description string + mtu uint32 + transHdrLen int + payloadSize int + allowPackets int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error + }{ + { + description: "No frag", + mtu: 2000, + payloadSize: 1000, + transHdrLen: 0, + allowPackets: 0, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on first frag", + mtu: 1300, + payloadSize: 3000, + transHdrLen: 0, + allowPackets: 0, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on second frag", + mtu: 1500, + payloadSize: 4000, + transHdrLen: 0, + allowPackets: 1, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on packet with MTU smaller than transport header", + mtu: 1280, + transHdrLen: 1500, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrMessageTooLong, + }, + } + + for _, ft := range tests { + t.Run(ft.description, func(t *testing.T) { + pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) + } + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 48a4c65e3..40da011f8 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1289,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.ep.linkEP.LinkAddress() + linkAddr := ndp.ep.nic.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { return false } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 25464a03a..9033a9ed5 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -410,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { naDst tcpip.Address }{ { - name: "Unspecified source to multicast destination", + name: "Unspecified source to solicited-node multicast destination", nsOpts: nil, nsSrcLinkAddr: remoteLinkAddr0, nsSrc: header.IPv6Any, @@ -437,11 +437,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsSrcLinkAddr: remoteLinkAddr0, nsSrc: header.IPv6Any, nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: false, - naSrc: nicAddr, - naDst: header.IPv6AllNodesMulticastAddress, + nsInvalid: true, }, { name: "Unspecified source with source ll option to unicast destination", diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index c9e57dc0d..d0ffc299a 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -8,6 +8,7 @@ go_library( "testutil.go", ], visibility = [ + "//pkg/tcpip/network/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", ], diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 2eaeab779..eba97334e 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,7 +56,6 @@ go_library( srcs = [ "addressable_endpoint_state.go", "conntrack.go", - "forwarder.go", "headertype_string.go", "icmp_rate_limit.go", "iptables.go", @@ -73,6 +72,7 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", + "pending_packets.go", "rand.go", "registration.go", "route.go", @@ -123,7 +123,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", @@ -139,7 +138,7 @@ go_test( name = "stack_test", size = "small", srcs = [ - "forwarder_test.go", + "forwarding_test.go", "linkaddrcache_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index db8ac1c2b..4d3acab96 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -679,11 +679,6 @@ type addressState struct { } } -// NetworkEndpoint implements AddressEndpoint. -func (a *addressState) NetworkEndpoint() NetworkEndpoint { - return a.addressableEndpointState.networkEndpoint -} - // AddressWithPrefix implements AddressEndpoint. func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { return a.addr diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 91ae7dafc..0cd1da11f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -196,13 +196,14 @@ type bucket struct { // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. +// +// Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader().View()) - if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := pkt.Network() + if netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol @@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { dstAddr: netHeader.DestinationAddress(), dstPort: tcpHeader.DestinationPort(), transProto: netHeader.TransportProtocol(), - netProto: header.IPv4ProtocolNumber, + netProto: pkt.NetworkProtocolNumber, }, nil } @@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction @@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { // support cases when they are validated, e.g. when we can't offload // receive checksumming. - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacketOutput manipulates ports for packets in Output hook. @@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction @@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacket will manipulate the port and address of the packet if the @@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return } @@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1 dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, transProto: header.TCPProtocolNumber, - netProto: header.IPv4ProtocolNumber, + netProto: netProto, } conn, _ := ct.connForTID(tid) if conn == nil { diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarding_test.go index 4e4b00a92..cf042309e 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -48,10 +48,9 @@ const ( type fwdTestNetworkEndpoint struct { AddressableEndpointState - nicID tcpip.NICID + nic NetworkInterface proto *fwdTestNetworkProtocol dispatcher TransportDispatcher - ep LinkEndpoint } var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) @@ -67,7 +66,7 @@ func (*fwdTestNetworkEndpoint) Enabled() bool { func (*fwdTestNetworkEndpoint) Disable() {} func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) + return f.nic.MTU() - uint32(f.MaxHeaderLength()) } func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { @@ -80,7 +79,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen + return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -99,7 +98,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH b[srcAddrOffset] = r.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) - return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) + return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt) } // WritePackets implements LinkEndpoint.WritePackets. @@ -159,10 +158,9 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { e := &fwdTestNetworkEndpoint{ - nicID: nic.ID(), + nic: nic, proto: f, dispatcher: dispatcher, - ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index faa503b00..8d6d9a7f1 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -502,11 +502,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, tcpip.ErrNotConnected } - return it.connections.originalDst(epID) + return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 08063f6ff..538c4625d 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -34,7 +34,7 @@ func (at *AcceptTarget) ID() TargetID { } // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -52,7 +52,7 @@ func (dt *DropTarget) ID() TargetID { } // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -76,7 +76,7 @@ func (et *ErrorTarget) ID() TargetID { } // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -99,7 +99,7 @@ func (uc *UserChainTarget) ID() TargetID { } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -118,7 +118,7 @@ func (rt *ReturnTarget) ID() TargetID { } // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -153,7 +153,7 @@ func (rt *RedirectTarget) ID() TargetID { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -164,11 +164,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1) in Output and - // to primary address of the incoming interface in Prerouting. + // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // the primary address of the incoming interface in Prerouting. switch hook { case Output: - rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + } else { + rt.Addr = header.IPv6Loopback + } case Prerouting: rt.Addr = address default: @@ -177,8 +181,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader().View()) - switch protocol := netHeader.TransportProtocol(); protocol { + switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) udpHeader.SetDestinationPort(rt.Port) @@ -186,10 +189,10 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(protocol, length) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) @@ -198,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - // Change destination address. - netHeader.SetDestinationAddress(rt.Addr) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + + pkt.Network().SetDestinationAddress(rt.Addr) + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 27e1feec0..4df288798 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -131,10 +131,17 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA defer entry.mu.Unlock() switch s := entry.neigh.State; s { - case Reachable, Static: + case Stale: + entry.handlePacketQueuedLocked() + fallthrough + case Reachable, Static, Delay, Probe: + // As per RFC 4861 section 7.3.3: + // "Neighbor Unreachability Detection operates in parallel with the sending + // of packets to a neighbor. While reasserting a neighbor's reachability, + // a node continues sending packets to that neighbor using the cached + // link-layer address." return entry.neigh, nil, nil - - case Unknown, Incomplete, Stale, Delay, Probe: + case Unknown, Incomplete: entry.addWakerLocked(w) if entry.done == nil { @@ -147,10 +154,8 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.handlePacketQueuedLocked() return entry.neigh, entry.done, tcpip.ErrWouldBlock - case Failed: return entry.neigh, nil, tcpip.ErrNoLinkAddress - default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index a0b7da5cd..fcd54ed83 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -1500,24 +1500,26 @@ func TestNeighborCacheReplace(t *testing.T) { } // Verify the entry exists - e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) - } - if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) - } - if t.Failed() { - t.FailNow() - } - want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + { + e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + if doneCh != nil { + t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + } + if t.Failed() { + t.FailNow() + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + } } // Notify of a link address change @@ -1536,28 +1538,34 @@ func TestNeighborCacheReplace(t *testing.T) { IsRouter: false, }) - // Requesting the entry again should start address resolution + // Requesting the entry again should start neighbor reachability confirmation. + // + // Verify the entry's new link address and the new state. { - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) } - clock.Advance(config.DelayFirstProbeTime + typicalLatency) - select { - case <-doneCh: - default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: updatedLinkAddr, + State: Delay, } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + clock.Advance(config.DelayFirstProbeTime + typicalLatency) } - // Verify the entry's new link address + // Verify that the neighbor is now reachable. { e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) clock.Advance(typicalLatency) if err != nil { t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) } - want = NeighborEntry{ + want := NeighborEntry{ Addr: entry.Addr, LocalAddr: entry.LocalAddr, LinkAddr: updatedLinkAddr, diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 9a72bec79..4d69a4de1 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index a265fff0a..e79abebca 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -227,8 +227,9 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e clock := faketime.NewManualClock() disp := testNUDDispatcher{} nic := NIC{ - id: entryTestNICID, - linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + + id: entryTestNICID, stack: &Stack{ clock: clock, nudDisp: &disp, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 23022292c..8828cc5fe 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -32,14 +32,18 @@ var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { + LinkEndpoint + stack *Stack id tcpip.NICID name string - linkEP LinkEndpoint context NICContext - stats NICStats - neigh *neighborCache + stats NICStats + neigh *neighborCache + + // The network endpoints themselves may be modified by calling the interface's + // methods, but the map reference and entries must be constant. networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. @@ -88,10 +92,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ + LinkEndpoint: ep, + stack: stack, id: id, name: name, - linkEP: ep, context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), @@ -127,11 +132,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } - nic.linkEP.Attach(nic) + nic.LinkEndpoint.Attach(nic) return nic } +func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { + return n.networkEndpoints[proto] +} + // Enabled implements NetworkInterface. func (n *NIC) Enabled() bool { return atomic.LoadUint32(&n.enabled) == 1 @@ -211,10 +220,9 @@ func (n *NIC) remove() *tcpip.Error { for _, ep := range n.networkEndpoints { ep.Close() } - n.networkEndpoints = nil // Detach from link endpoint, so no packet comes in. - n.linkEP.Attach(nil) + n.LinkEndpoint.Attach(nil) return nil } @@ -234,7 +242,64 @@ func (n *NIC) isPromiscuousMode() bool { // IsLoopback implements NetworkInterface. func (n *NIC) IsLoopback() bool { - return n.linkEP.Capabilities()&CapabilityLoopback != 0 + return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0 +} + +// WritePacket implements NetworkLinkEndpoint. +func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + // As per relevant RFCs, we should queue packets while we wait for link + // resolution to complete. + // + // RFC 1122 section 2.3.2.2 (for IPv4): + // The link layer SHOULD save (rather than discard) at least + // one (the latest) packet of each set of packets destined to + // the same unresolved IP address, and transmit the saved + // packet when the address has been resolved. + // + // RFC 4861 section 5.2 (for IPv6): + // Once the IP address of the next-hop node is known, the sender + // examines the Neighbor Cache for link-layer information about that + // neighbor. If no entry exists, the sender creates one, sets its state + // to INCOMPLETE, initiates Address Resolution, and then queues the data + // packet pending completion of address resolution. + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + r := r.Clone() + n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + return nil + } + return err + } + + return n.writePacket(r, gso, protocol, pkt) +} + +func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Size() + + if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { + return err + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil +} + +// WritePackets implements NetworkLinkEndpoint. +func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution + // is being peformed like WritePacket. + writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) + n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Size() + } + + n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + return writtenPackets, err } // setSpoofing enables or disables address spoofing. @@ -483,9 +548,9 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + defer r.Release() r.RemoteLinkAddress = remotelinkAddr - addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) - addressEndpoint.DecRef() + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -519,7 +584,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // If no local link layer address is provided, assume it was sent // directly to this NIC. if local == "" { - local = n.linkEP.LinkAddress() + local = n.LinkEndpoint.LinkAddress() } // Are any packet type sockets listening for this network protocol? @@ -599,11 +664,11 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n := r.nic if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.linkEP.LinkAddress() + r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) addressEndpoint.DecRef() r.Release() return @@ -614,21 +679,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. - // TODO(b/128629022): move this logic to route.WritePacket. // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) - // forwarder will release route. - return - } + + // pkt may have set its header and may not have enough headroom for + // link-layer header for the other link to prepend. Here we create a new + // packet to forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - r.Release() - return } - // The link-address resolution finished immediately. - n.forwardPacket(&r, protocol, pkt) r.Release() return } @@ -652,43 +717,18 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. - n.linkEP.AddHeader(local, remote, protocol, p) + n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - // pkt may have set its header and may not have enough headroom for link-layer - // header for the other link to prepend. Here we create a new packet to - // forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - }) - - // WritePacket takes ownership of fwdPkt, calculate numBytes first. - numBytes := fwdPkt.Size() - - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return - } - - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - // TODO(gvisor.dev/issue/4365): Let the caller know that the transport - // protocol is unrecognized. n.stack.stats.UnknownProtocolRcvdPackets.Increment() - return TransportPacketHandled + return TransportPacketProtocolUnreachable } transProto := state.proto @@ -792,11 +832,6 @@ func (n *NIC) Name() string { return n.name } -// LinkEndpoint implements NetworkInterface. -func (n *NIC) LinkEndpoint() LinkEndpoint { - return n.linkEP -} - // nudConfigs gets the NUD configurations for n. func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { if n.neigh == nil { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index fdd49b77f..97a96af62 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -33,8 +33,7 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil) type testIPv6Endpoint struct { AddressableEndpointState - nicID tcpip.NICID - linkEP LinkEndpoint + nic NetworkInterface protocol *testIPv6Protocol invalidatedRtr tcpip.Address @@ -57,12 +56,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 { // MTU implements NetworkEndpoint.MTU. func (e *testIPv6Endpoint) MTU() uint32 { - return e.linkEP.MTU() - header.IPv6MinimumSize + return e.nic.MTU() - header.IPv6MinimumSize } // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize + return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } // WritePacket implements NetworkEndpoint.WritePacket. @@ -134,8 +133,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint implements NetworkProtocol.NewEndpoint. func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ - nicID: nic.ID(), - linkEP: nic.LinkEndpoint(), + nic: nic, protocol: p, } e.AddressableEndpointState.Init(e) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index a7d9d59fa..105583c49 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) type headerType int @@ -255,6 +256,20 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { return newPk } +// Network returns the network header as a header.Network. +// +// Network should only be called when NetworkHeader has been set. +func (pk *PacketBuffer) Network() header.Network { + switch netProto := pk.NetworkProtocolNumber; netProto { + case header.IPv4ProtocolNumber: + return header.IPv4(pk.NetworkHeader().View()) + case header.IPv6ProtocolNumber: + return header.IPv6(pk.NetworkHeader().View()) + default: + panic(fmt.Sprintf("unknown network protocol number %d", netProto)) + } +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/pending_packets.go index 3eff141e6..f838eda8d 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -29,60 +29,60 @@ const ( ) type pendingPacket struct { - nic *NIC route *Route proto tcpip.NetworkProtocolNumber pkt *PacketBuffer } -type forwardQueue struct { +// packetsPendingLinkResolution is a queue of packets pending link resolution. +// +// Once link resolution completes successfully, the packets will be written. +type packetsPendingLinkResolution struct { sync.Mutex // The packets to send once the resolver completes. - packets map[<-chan struct{}][]*pendingPacket + packets map[<-chan struct{}][]pendingPacket // FIFO of channels used to cancel the oldest goroutine waiting for // link-address resolution. cancelChans []chan struct{} } -func newForwardQueue() *forwardQueue { - return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} +func (f *packetsPendingLinkResolution) init() { + f.Lock() + defer f.Unlock() + f.packets = make(map[<-chan struct{}][]pendingPacket) } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - shouldWait := false - +func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { f.Lock() + defer f.Unlock() + packets, ok := f.packets[ch] - if !ok { - shouldWait = true - } - for len(packets) == maxPendingPacketsPerResolution { + if len(packets) == maxPendingPacketsPerResolution { p := packets[0] + packets[0] = pendingPacket{} packets = packets[1:] - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Stats().IP.OutgoingPacketErrors.Increment() p.route.Release() } + if l := len(packets); l >= maxPendingPacketsPerResolution { panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) } - f.packets[ch] = append(packets, &pendingPacket{ - nic: n, + + f.packets[ch] = append(packets, pendingPacket{ route: r, - proto: protocol, + proto: proto, pkt: pkt, }) - f.Unlock() - if !shouldWait { + if ok { return } // Wait for the link-address resolution to complete. - // Start a goroutine with a forwarding-cancel channel so that we can - // limit the maximum number of goroutines running concurrently. - cancel := f.newCancelChannel() + cancel := f.newCancelChannelLocked() go func() { cancelled := false select { @@ -92,17 +92,21 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc } f.Lock() - packets := f.packets[ch] + packets, ok := f.packets[ch] delete(f.packets, ch) f.Unlock() + if !ok { + panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) + } + for _, p := range packets { if cancelled { - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Stats().IP.OutgoingPacketErrors.Increment() } else if _, err := p.route.Resolve(nil); err != nil { - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { - p.nic.forwardPacket(p.route, p.proto, p.pkt) + p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } p.route.Release() } @@ -112,12 +116,10 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc // newCancelChannel creates a channel that can cancel a pending forwarding // activity. The oldest channel is closed if the number of open channels would // exceed maxPendingResolutions. -func (f *forwardQueue) newCancelChannel() chan struct{} { - f.Lock() - defer f.Unlock() - +func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { if len(f.cancelChans) == maxPendingResolutions { ch := f.cancelChans[0] + f.cancelChans[0] = nil f.cancelChans = f.cancelChans[1:] close(ch) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b6f823b54..defb9129b 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -208,6 +208,10 @@ const ( // transport layer and callers need not take any further action. TransportPacketHandled TransportPacketDisposition = iota + // TransportPacketProtocolUnreachable indicates that the transport + // protocol requested in the packet is not supported. + TransportPacketProtocolUnreachable + // TransportPacketDestinationPortUnreachable indicates that there weren't any // listeners interested in the packet and the transport protocol has no means // to notify the sender. @@ -322,10 +326,6 @@ const ( // AssignableAddressEndpoint is a reference counted address endpoint that may be // assigned to a NetworkEndpoint. type AssignableAddressEndpoint interface { - // NetworkEndpoint returns the NetworkEndpoint the receiver is associated - // with. - NetworkEndpoint() NetworkEndpoint - // AddressWithPrefix returns the endpoint's address. AddressWithPrefix() tcpip.AddressWithPrefix @@ -475,6 +475,8 @@ type NDPEndpoint interface { // NetworkInterface is a network interface. type NetworkInterface interface { + NetworkLinkEndpoint + // ID returns the interface's ID. ID() tcpip.NICID @@ -488,9 +490,6 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool - - // LinkEndpoint returns the link endpoint backing the interface. - LinkEndpoint() LinkEndpoint } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -663,22 +662,15 @@ const ( CapabilitySoftwareGSO ) -// LinkEndpoint is the interface implemented by data link layer protocols (e.g., -// ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. When a link header exists, -// it sets each PacketBuffer's LinkHeader field before passing it up the -// stack. -type LinkEndpoint interface { +// NetworkLinkEndpoint is a data-link layer that supports sending network +// layer packets. +type NetworkLinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a // physical network doesn't exist, the limit is generally 64k, which // includes the maximum size of an IP packet. MTU() uint32 - // Capabilities returns the set of capabilities supported by the - // endpoint. - Capabilities() LinkEndpointCapabilities - // MaxHeaderLength returns the maximum size the data link (and // lower level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -686,7 +678,7 @@ type LinkEndpoint interface { MaxHeaderLength() uint16 // LinkAddress returns the link address (typically a MAC) of the - // link endpoint. + // endpoint. LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the @@ -706,6 +698,19 @@ type LinkEndpoint interface { // offload is enabled. If it will be used for something else, it may // require to change syscall filters. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) +} + +// LinkEndpoint is the interface implemented by data link layer protocols (e.g., +// ethernet, loopback, raw) and used by network layer protocols to send packets +// out through the implementer's data link endpoint. When a link header exists, +// it sets each PacketBuffer's LinkHeader field before passing it up the +// stack. +type LinkEndpoint interface { + NetworkLinkEndpoint + + // Capabilities returns the set of capabilities supported by the + // endpoint. + Capabilities() LinkEndpointCapabilities // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. It takes ownership of vv. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 5ade3c832..25f80c1f8 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -72,21 +72,20 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } - linkEP := nic.LinkEndpoint() r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: linkEP.LinkAddress(), + LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), RemoteAddress: remoteAddr, addressEndpoint: addressEndpoint, nic: nic, Loop: loop, } - if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { + if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = nic.stack + r.linkCache = r.nic.stack } } @@ -100,7 +99,7 @@ func (r *Route) NICID() tcpip.NICID { // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength() + return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. @@ -116,23 +115,17 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint().Capabilities() + return r.nic.LinkEndpoint.Capabilities() } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok { + if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 } -// ResolveWith immediately resolves a route with the specified remote link -// address. -func (r *Route) ResolveWith(addr tcpip.LinkAddress) { - r.RemoteLinkAddress = addr -} - // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -208,17 +201,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf return tcpip.ErrInvalidEndpointState } - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Size() - - err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - r.nic.stats.Tx.Packets.Increment() - r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) - } - return err + return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns @@ -228,22 +211,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } - // WritePackets takes ownership of pkt, calculate length first. - numPkts := pkts.Len() - - n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) - } - r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - - writtenBytes := 0 - for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Size() - } - - r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) - return n, err + return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network @@ -253,32 +221,17 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Data.Size() - - if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return err - } - r.nic.stats.Tx.Packets.Increment() - r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) - return nil + return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.addressEndpoint.NetworkEndpoint().DefaultTTL() + return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.addressEndpoint.NetworkEndpoint().MTU() -} - -// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying -// network endpoint. -func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber() + return r.nic.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 57d8e79e0..3a07577c8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -436,9 +436,9 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // forwarder holds the packets that wait for their link-address resolutions - // to complete, and forwards them when each resolution is done. - forwarder *forwardQueue + // linkResQueue holds packets that are waiting for link resolution to + // complete. + linkResQueue packetsPendingLinkResolution // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. @@ -550,8 +550,8 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto +func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := t.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: netProto = header.IPv4ProtocolNumber @@ -565,7 +565,7 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } } - switch len(e.ID.LocalAddress) { + switch len(t.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState @@ -577,8 +577,8 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } switch { - case netProto == e.NetProto: - case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: + case netProto == t.NetProto: + case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: if v6only { return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute } @@ -640,7 +640,6 @@ func New(opts Options) *Stack { useNeighborCache: opts.UseNeighborCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, - forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, @@ -653,6 +652,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } + s.linkResQueue.init() // Add specified network protocols. for _, netProtoFactory := range opts.NetworkProtocols { @@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } -// GetNICByName gets the NIC specified by name. -func (s *Stack) GetNICByName(name string) (*NIC, bool) { +// GetLinkEndpointByName gets the link endpoint specified by name. +func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { s.mu.RLock() defer s.mu.RUnlock() for _, nic := range s.nics { if nic.Name() == name { - return nic, true + return nic.LinkEndpoint } } - return nil, false + return nil } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { } nics[id] = NICInfo{ Name: nic.name, - LinkAddress: nic.linkEP.LinkAddress(), + LinkAddress: nic.LinkEndpoint.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), Flags: flags, - MTU: nic.linkEP.MTU(), + MTU: nic.LinkEndpoint.MTU(), Stats: nic.stats, Context: nic.context, - ARPHardwareType: nic.linkEP.ARPHardwareType(), + ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), } } return nics @@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) } // Neighbors returns all IP to MAC address associations. @@ -1539,7 +1539,7 @@ func (s *Stack) Wait() { s.mu.RLock() defer s.mu.RUnlock() for _, n := range s.nics { - n.linkEP.Wait() + n.LinkEndpoint.Wait() } } @@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t // Add our own fake ethernet header. ethFields := header.EthernetFields{ - SrcAddr: nic.linkEP.LinkAddress(), + SrcAddr: nic.LinkEndpoint.LinkAddress(), DstAddr: dst, Type: netProto, } @@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t vv := buffer.View(fakeHeader).ToVectorisedView() vv.Append(payload) - if err := nic.linkEP.WriteRawPacket(vv); err != nil { + if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { return err } @@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) return tcpip.ErrUnknownDevice } - if err := nic.linkEP.WriteRawPacket(payload); err != nil { + if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { return err } @@ -1796,7 +1796,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco return nil, tcpip.ErrUnknownNICID } - return nic.networkEndpoints[proto], nil + return nic.getNetworkEndpoint(proto), nil } // NUDConfigurations gets the per-interface NUD configurations. @@ -1873,10 +1873,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres if addressEndpoint == nil { continue } - - ep := addressEndpoint.NetworkEndpoint() addressEndpoint.DecRef() - return ep, nil + return nic.getNetworkEndpoint(netProto), nil } return nil, tcpip.ErrBadAddress } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index aa20f750b..38994cca1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,7 +21,6 @@ import ( "bytes" "fmt" "math" - "net" "sort" "testing" "time" @@ -35,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -77,10 +75,9 @@ type fakeNetworkEndpoint struct { enabled bool } - nicID tcpip.NICID + nic stack.NetworkInterface proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { @@ -103,7 +100,7 @@ func (f *fakeNetworkEndpoint) Disable() { } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) + return f.nic.MTU() - uint32(f.MaxHeaderLength()) } func (*fakeNetworkEndpoint) DefaultTTL() uint8 { @@ -135,7 +132,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fakeNetHeaderLen + return f.nic.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -164,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params return nil } - return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) + return f.nic.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -216,10 +213,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &fakeNetworkEndpoint{ - nicID: nic.ID(), + nic: nic, proto: f, dispatcher: dispatcher, - ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e @@ -2106,7 +2102,7 @@ func TestNICStats(t *testing.T) { t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { + if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -3502,52 +3498,6 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } } -func TestResolveWith(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) - - remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4()) - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) - } - defer r.Release() - - // Should initially require resolution. - if !r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = false, want = true") - } - - // Manually resolving the route should no longer require resolution. - r.ResolveWith("\x01") - if r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = true, want = false") - } -} - // TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its // associated address is removed should not cause a panic. func TestRouteReleaseAfterAddrRemoval(t *testing.T) { diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 06c7a3cd3..a4f141253 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -6,6 +6,8 @@ go_test( name = "integration_test", size = "small", srcs = [ + "forward_test.go", + "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", ], @@ -15,6 +17,8 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/pipe", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go new file mode 100644 index 000000000..ffd38ee1a --- /dev/null +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -0,0 +1,378 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/pipe" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +func TestForwarding(t *testing.T) { + const ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") + routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1NICID = 1 + routerNICID1 = 2 + routerNICID2 = 3 + host2NICID = 4 + + listenPort = 8080 + ) + + host1IPv4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + routerNIC1IPv4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + routerNIC2IPv4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + } + host2IPv4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), + PrefixLen: 8, + }, + } + host1IPv6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, + } + routerNIC1IPv6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + routerNIC2IPv6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("b::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("b::2").To16()), + PrefixLen: 64, + }, + } + + type endpointAndAddresses struct { + serverEP tcpip.Endpoint + serverAddr tcpip.Address + serverReadableCH chan struct{} + + clientEP tcpip.Endpoint + clientAddr tcpip.Address + clientReadableCH chan struct{} + } + + newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(transProto, netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) + } + + t.Cleanup(func() { + wq.EventUnregister(&we) + }) + + return ep, ch + } + + tests := []struct { + name string + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses + }{ + { + name: "IPv4 host1 server with host2 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host1IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + { + name: "IPv6 host2 server with host1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host2IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host1IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + + host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired) + routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + + if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + tcpip.Route{ + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + tcpip.Route{ + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + tcpip.Route{ + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + tcpip.Route{ + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + tcpip.Route{ + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + tcpip.Route{ + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack) + defer epsAndAddrs.serverEP.Close() + defer epsAndAddrs.clientEP.Close() + + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) { + t.Helper() + + dataPayload := tcpip.SlicePayload(data) + wOpts := tcpip.WriteOptions{To: to} + n, ch, err := ep.Write(dataPayload, wOpts) + if err == tcpip.ErrNoLinkAddress { + // Wait for link resolution to complete. + <-ch + + n, _, err = ep.Write(dataPayload, wOpts) + } else if err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } + + if err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want) + } + } + + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data, &serverAddr) + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress { + t.Helper() + + // Wait for the endpoint to be readable. + <-ch + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + + if diff := cmp.Diff(v, buffer.View(data)); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != expectedFrom { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom) + } + + if t.Failed() { + t.FailNow() + } + + return addr + } + + addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr) + // Unspecify the NIC since NIC IDs are meaningless across stacks. + addr.NIC = 0 + + data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + write(epsAndAddrs.serverEP, data, &addr) + addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr) + if addr.Port != listenPort { + t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go new file mode 100644 index 000000000..bf3a6f6ee --- /dev/null +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -0,0 +1,219 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/pipe" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/waiter" +) + +var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 8, + }, + } + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, + } +) + +// TestPing tests that two hosts can ping eachother when link resolution is +// enabled. +func TestPing(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + + // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo + // request/reply packets. + icmpDataOffset = 8 + ) + + tests := []struct { + name string + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + icmpBuf func(*testing.T) buffer.View + }{ + { + name: "IPv4 Ping", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + icmpBuf: func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) + hdr.SetType(header.ICMPv4Echo) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + }, + }, + { + name: "IPv6 Ping", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + icmpBuf: func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) + hdr.SetType(header.ICMPv6EchoRequest) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + } + + host1Stack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + + host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + + if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + tcpip.Route{ + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + }) + + var wq waiter.Queue + we, waiterCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() + + // The first write should trigger link resolution. + icmpBuf := test.icmpBuf(t) + wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} + if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress) + } else { + // Wait for link resolution to complete. + <-ch + } + if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } else if want := int64(len(icmpBuf)); n != want { + t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want) + } + + // Wait for the endpoint to be readable. + <-waiterCH + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != test.remoteAddr { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 72d86b5ab..4f2ca7f54 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -203,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View()) + src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6891fd245..0aaef495d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -804,7 +804,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() + pkt.NetworkProtocolNumber = r.NetProto data.ReadToVV(&pkt.Data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) @@ -1219,12 +1219,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { return true, nil } - // Increase counter if after processing the segment we would potentially - // advertise a zero window. - if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above { - e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - } - // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 7ad894840..3bcd3923a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -248,6 +248,11 @@ type ReceiveErrors struct { // ZeroRcvWindowState is the number of times we advertised // a zero receive window when rcvList is full. ZeroRcvWindowState tcpip.StatCounter + + // WantZeroWindow is the number of times we wanted to advertise a + // zero receive window but couldn't because it would have caused + // the receive window's right edge to shrink. + WantZeroRcvWindow tcpip.StatCounter } // SendErrors collect segment send errors within the transport layer. @@ -1162,7 +1167,7 @@ func (e *endpoint) cleanupLocked() { // wndFromSpace returns the window that we can advertise based on the available // receive buffer space. func wndFromSpace(space int) int { - return space / (1 << rcvAdvWndScale) + return space >> rcvAdvWndScale } // initialReceiveWindow returns the initial receive window to advertise in the @@ -1518,6 +1523,38 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } +// selectWindowLocked returns the new window without checking for shrinking or scaling +// applied. +// Precondition: e.mu and e.rcvListMu must be held. +func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { + wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked()) + maxWindow := wndFromSpace(e.rcvBufSize) + wndFromUsedBytes := maxWindow - e.rcvBufUsed + + // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in + // cases where we receive a lot of small segments the segment overhead is a + // lot higher and we can run out socket buffer space before we can fill the + // previous window we advertised. In cases where we receive MSS sized or close + // MSS sized segments we will probably run out of window space before we + // exhaust receive buffer. + newWnd := wndFromAvailable + if newWnd > wndFromUsedBytes { + newWnd = wndFromUsedBytes + } + if newWnd < 0 { + newWnd = 0 + } + return seqnum.Size(newWnd) +} + +// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu. +func (e *endpoint) selectWindow() (wnd seqnum.Size) { + e.rcvListMu.Lock() + wnd = e.selectWindowLocked() + e.rcvListMu.Unlock() + return wnd +} + // windowCrossedACKThresholdLocked checks if the receive window to be announced // would be under aMSS or under the window derived from half receive buffer, // whichever smaller. This is useful as a receive side silly window syndrome @@ -1534,7 +1571,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // Precondition: e.mu and e.rcvListMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := wndFromSpace(e.receiveBufferAvailableLocked()) + newAvail := int(e.selectWindowLocked()) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 @@ -2099,7 +2136,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID) + addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) e.UnlockUser() if err != nil { return err @@ -3013,6 +3050,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { EndSequence: rc.endSequence, FACK: rc.fack, RTT: rc.rtt, + Reord: rc.reorderSeen, } return s } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index d969ca23a..d312b1b8b 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -29,26 +29,36 @@ import ( // // +stateify savable type rackControl struct { - // xmitTime is the latest transmission timestamp of rackControl.seg. - xmitTime time.Time `state:".(unixTime)"` - // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value + // dsack indicates if the connection has seen a DSACK. + dsack bool + // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value + // minRTT is the estimated minimum RTT of the connection. + minRTT time.Duration + // rtt is the RTT of the most recently delivered packet on the // connection (either cumulatively acknowledged or selectively // acknowledged) that was not marked invalid as a possible spurious // retransmission. rtt time.Duration + + // reorderSeen indicates if reordering has been detected on this + // connection. + reorderSeen bool + + // xmitTime is the latest transmission timestamp of rackControl.seg. + xmitTime time.Time `state:".(unixTime)"` } -// Update will update the RACK related fields when an ACK has been received. +// update will update the RACK related fields when an ACK has been received. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 -func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) { +func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) { rtt := time.Now().Sub(seg.xmitTime) // If the ACK is for a retransmitted packet, do not update if it is a @@ -65,12 +75,21 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, return } } - if rtt < srtt { + if rtt < rc.minRTT { return } } rc.rtt = rtt + + // The sender can either track a simple global minimum of all RTT + // measurements from the connection, or a windowed min-filtered value + // of recent RTT measurements. This implementation keeps track of the + // simple global minimum of all RTTs for the connection. + if rtt < rc.minRTT || rc.minRTT == 0 { + rc.minRTT = rtt + } + // Update rc.xmitTime and rc.endSequence to the transmit time and // ending sequence number of the packet which has been acknowledged // most recently. @@ -80,3 +99,26 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, rc.endSequence = endSeq } } + +// detectReorder detects if packet reordering has been observed. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +// * Step 3: Detect data segment reordering. +// To detect reordering, the sender looks for original data segments being +// delivered out of order. To detect such cases, the sender tracks the +// highest sequence selectively or cumulatively acknowledged in the RACK.fack +// variable. The name "fack" stands for the most "Forward ACK" (this term is +// adopted from [FACK]). If a never retransmitted segment that's below +// RACK.fack is (selectively or cumulatively) acknowledged, it has been +// delivered out of order. The sender sets RACK.reord to TRUE if such segment +// is identified. +func (rc *rackControl) detectReorder(seg *segment) { + endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + if rc.fack.LessThan(endSeq) { + rc.fack = endSeq + return + } + + if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 { + rc.reorderSeen = true + } +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 48bf196d8..8e0b7c843 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -43,6 +43,9 @@ type receiver struct { // rcvWnd is the non-scaled receive window last advertised to the peer. rcvWnd seqnum.Size + // rcvWUP is the rcvNxt value at the last window update sent. + rcvWUP seqnum.Value + rcvWndScale uint8 closed bool @@ -64,6 +67,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), rcvWnd: rcvWnd, + rcvWUP: irs + 1, rcvWndScale: rcvWndScale, lastRcvdAckTime: time.Now(), } @@ -84,34 +88,54 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) } +// currentWindow returns the available space in the window that was advertised +// last to our peer. +func (r *receiver) currentWindow() (curWnd seqnum.Size) { + endOfWnd := r.rcvWUP.Add(r.rcvWnd) + if endOfWnd.LessThan(r.rcvNxt) { + // return 0 if r.rcvNxt is past the end of the previously advertised window. + // This can happen because we accept a large segment completely even if + // accepting it causes it to partially exceed the advertised window. + return 0 + } + return r.rcvNxt.Size(endOfWnd) +} + // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - avail := wndFromSpace(r.ep.receiveBufferAvailable()) - if avail == 0 { - // We have no space available to accept any data, move to zero window - // state. - r.rcvWnd = 0 - return r.rcvNxt, 0 - } - - acc := r.rcvNxt.Add(seqnum.Size(avail)) - newWnd := r.rcvNxt.Size(acc) - curWnd := r.rcvNxt.Size(r.rcvAcc) - + newWnd := r.ep.selectWindow() + curWnd := r.currentWindow() // Update rcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we // would end up dropping bytes that might already be in flight. - if newWnd > curWnd { - r.rcvAcc = r.rcvNxt.Add(newWnd) + // ==================================================== sequence space. + // ^ ^ ^ ^ + // rcvWUP rcvNxt rcvAcc new rcvAcc + // <=====curWnd ===> + // <========= newWnd > curWnd ========= > + if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) { + // If the new window moves the right edge, then update rcvAcc. + r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) } else { + if newWnd == 0 { + // newWnd is zero but we can't advertise a zero as it would cause window + // to shrink so just increment a metric to record this event. + r.ep.stats.ReceiveErrors.WantZeroRcvWindow.Increment() + } newWnd = curWnd } // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. r.rcvWnd = newWnd - return r.rcvNxt, r.rcvWnd >> r.rcvWndScale + r.rcvWUP = r.rcvNxt + scaledWnd := r.rcvWnd >> r.rcvWndScale + if scaledWnd == 0 { + // Increment a metric if we are advertising an actual zero window. + r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + return r.rcvNxt, scaledWnd } // nonZeroWindow is called when the receive window grows from zero to nonzero; diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 13acaf753..1f9c5cf50 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -71,6 +71,9 @@ type segment struct { // xmitTime is the last transmit time of this segment. xmitTime time.Time `state:".(unixTime)"` xmitCount uint32 + + // acked indicates if the segment has already been SACKed. + acked bool } func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index c55589c45..6fa8d63cd 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -17,6 +17,7 @@ package tcp import ( "fmt" "math" + "sort" "sync/atomic" "time" @@ -263,6 +264,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint highRxt: iss, rescueRxt: iss, }, + rc: rackControl{ + fack: iss, + }, gso: ep.gso != nil, } @@ -1274,6 +1278,39 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { return true } +// Iterate the writeList and update RACK for each segment which is newly acked +// either cumulatively or selectively. Loop through the segments which are +// sacked, and update the RACK related variables and check for reordering. +// +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +// steps 2 and 3. +func (s *sender) walkSACK(rcvdSeg *segment) { + // Sort the SACK blocks. The first block is the most recent unacked + // block. The following blocks can be in arbitrary order. + sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks)) + copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks) + sort.Slice(sackBlocks, func(i, j int) bool { + return sackBlocks[j].Start.LessThan(sackBlocks[i].Start) + }) + + seg := s.writeList.Front() + for _, sb := range sackBlocks { + // This check excludes DSACK blocks. + if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) { + continue + } + + for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { + if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { + s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.detectReorder(seg) + seg.acked = true + } + seg = seg.Next() + } + } +} + // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { @@ -1308,6 +1345,21 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { rcvdSeg.hasNewSACKInfo = true } } + + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08 + // section-7.2 + // * Step 2: Update RACK stats. + // If the ACK is not ignored as invalid, update the RACK.rtt + // to be the RTT sample calculated using this ACK, and + // continue. If this ACK or SACK was for the most recently + // sent packet, then record the RACK.xmit_ts timestamp and + // RACK.end_seq sequence implied by this ACK. + // * Step 3: Detect packet reordering. + // If the ACK selectively or cumulatively acknowledges an + // unacknowledged and also never retransmitted sequence below + // RACK.fack, then the corresponding packet has been + // reordered and RACK.reord is set to TRUE. + s.walkSACK(rcvdSeg) s.SetPipe() } @@ -1365,9 +1417,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ackLeft := acked originalOutstanding := s.outstanding - s.rtt.Lock() - srtt := s.rtt.srtt - s.rtt.Unlock() for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1388,13 +1437,14 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Update the RACK fields if SACK is enabled. - if s.ep.sackPermitted { - s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset) + if s.ep.sackPermitted && !seg.acked { + s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.detectReorder(seg) } s.writeList.Remove(seg) - // if SACK is enabled then Only reduce outstanding if + // If SACK is enabled then Only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index e03f101e8..d3f92b48c 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -21,17 +21,20 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" ) +const ( + maxPayload = 10 + tsOptionSize = 12 + maxTCPOptionSize = 40 +) + // TestRACKUpdate tests the RACK related fields are updated when an ACK is // received on a SACK enabled connection. func TestRACKUpdate(t *testing.T) { - const maxPayload = 10 - const tsOptionSize = 12 - const maxTCPOptionSize = 40 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) defer c.Cleanup() @@ -49,7 +52,7 @@ func TestRACKUpdate(t *testing.T) { } if state.Sender.RACKState.RTT == 0 { - t.Fatalf("RACK RTT failed to update when an ACK is received") + t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") } }) setStackSACKPermitted(t, c, true) @@ -69,6 +72,66 @@ func TestRACKUpdate(t *testing.T) { bytesRead := 0 c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload - c.SendAck(790, bytesRead) + c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) time.Sleep(200 * time.Millisecond) } + +// TestRACKDetectReorder tests that RACK detects packet reordering. +func TestRACKDetectReorder(t *testing.T) { + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + defer c.Cleanup() + + const ackNum = 2 + + var n int + ch := make(chan struct{}) + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + gotSeq := state.Sender.RACKState.FACK + wantSeq := state.Sender.SndNxt + // FACK should be updated to the highest ending sequence number of the + // segment acknowledged most recently. + if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { + t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq) + } + + n++ + if n < ackNum { + if state.Sender.RACKState.Reord { + t.Fatalf("RACK reorder detected when there is no reordering") + } + return + } + + if state.Sender.RACKState.Reord == false { + t.Fatalf("RACK reorder detection failed") + } + close(ch) + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + data := buffer.NewView(ackNum * maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + for i := 0; i < ackNum; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + + start := c.IRS.Add(maxPayload + 1) + end := start.Add(maxPayload) + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-ch +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 5b504d0d1..a7149efd0 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6264,14 +6264,27 @@ func TestReceiveBufferAutoTuning(t *testing.T) { rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ + if i == 0 { // In the first iteration the receiver based RTT is not // yet known as a result the moderation code should not // increase the advertised window. rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) } else { - pkt := c.GetPacket() - curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale + // Read loop above could generate an ACK if the window had dropped to + // zero and then read had opened it up. + lastACK := c.GetPacket() + // Discard any intermediate ACKs and only check the last ACK we get in a + // short time period of few ms. + for { + time.Sleep(1 * time.Millisecond) + pkt := c.GetPacketNonBlocking() + if pkt == nil { + break + } + lastACK = pkt + } + curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale // If thew new current window is close maxReceiveBufferSize then terminate // the loop. This can happen before all iterations are done due to timing // differences when running the test. @@ -7328,7 +7341,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf * 2 + remain := rcvBuf sent := 0 data := make([]byte, defaultMTU/2) @@ -7343,7 +7356,6 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { }) sent += len(data) remain -= len(data) - checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index faf51ef95..4d7847142 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -68,9 +68,9 @@ const ( // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - // testInitialSequenceNumber is the initial sequence number sent in packets that + // TestInitialSequenceNumber is the initial sequence number sent in packets that // are sent in response to a SYN or in the initial SYN sent to the stack. - testInitialSequenceNumber = 789 + TestInitialSequenceNumber = 789 ) // StackAddrWithPrefix is StackAddr with its associated prefix length. @@ -505,7 +505,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.TCP( checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -532,7 +532,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.TCP( checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -912,7 +912,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Build SYN-ACK. c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) - iss := seqnum.Value(testInitialSequenceNumber) + iss := seqnum.Value(TestInitialSequenceNumber) c.SendPacket(nil, &Headers{ SrcPort: tcpSeg.DestinationPort(), DstPort: tcpSeg.SourcePort(), @@ -1084,7 +1084,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions offset += paddingToAdd // Send a SYN request. - iss := seqnum.Value(testInitialSequenceNumber) + iss := seqnum.Value(TestInitialSequenceNumber) c.SendPacket(nil, &Headers{ SrcPort: TestPort, DstPort: StackPort, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4e14a5fc5..b4604ba35 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1432,7 +1432,9 @@ func TestNoChecksum(t *testing.T) { var _ stack.NetworkInterface = (*testInterface)(nil) -type testInterface struct{} +type testInterface struct { + stack.NetworkLinkEndpoint +} func (*testInterface) ID() tcpip.NICID { return 0 @@ -1450,10 +1452,6 @@ func (*testInterface) Enabled() bool { return true } -func (*testInterface) LinkEndpoint() stack.LinkEndpoint { - return nil -} - func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1780,16 +1778,26 @@ func TestV4UnknownDestination(t *testing.T) { checker.ICMPv4Type(header.ICMPv4DstUnreachable), checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + // We need to compare the included data part of the UDP packet that is in + // the ICMP packet with the matching original data. icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) + incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize wantLen := len(payload) if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + // To work out the data size we need to simulate what the sender would + // have done. The wanted size is the total available minus the sum of + // the headers in the UDP AND ICMP packets, given that we know the test + // had only a minimal IP header but the ICMP sender will have allowed + // for a maximally sized packet header. + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength + } - // In case of large payloads the IP packet may be truncated. Update + // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + // Add back the two headers within the payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) origDgram := header.UDP(payloadIPHeader.Payload()) if got, want := len(origDgram.Payload()), wantLen; got != want { @@ -2015,7 +2023,8 @@ func TestPayloadModifiedV4(t *testing.T) { payload := newPayload() h := unicastV4.header4Tuple(incoming) buf := c.buildV4Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be incorrect. + // Modify the payload so that the checksum value in the UDP header will be + // incorrect. buf[len(buf)-1]++ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), @@ -2045,7 +2054,8 @@ func TestPayloadModifiedV6(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be incorrect. + // Modify the payload so that the checksum value in the UDP header will be + // incorrect. buf[len(buf)-1]++ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), |