summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/header/ndp_options.go213
-rw-r--r--pkg/tcpip/header/ndp_test.go574
-rw-r--r--pkg/tcpip/seqnum/seqnum.go5
-rw-r--r--pkg/tcpip/stack/ndp.go18
-rw-r--r--pkg/tcpip/stack/ndp_test.go125
-rw-r--r--pkg/tcpip/transport/tcp/BUILD10
-rw-r--r--pkg/tcpip/transport/tcp/accept.go14
-rw-r--r--pkg/tcpip/transport/tcp/connect.go15
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go13
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go23
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go74
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD1
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go13
13 files changed, 1055 insertions, 43 deletions
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 444e90820..5d3975c56 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -45,6 +45,10 @@ const (
// NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
// Server option, as per RFC 8106 section 5.1.
NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25
+
+ // NDPDNSSearchListOptionType is the type of the DNS Search List option,
+ // as per RFC 8106 section 5.2.
+ NDPDNSSearchListOptionType = 31
)
const (
@@ -115,6 +119,27 @@ const (
// RFC 8106 section 5.3.1.
minNDPRecursiveDNSServerBodySize = 22
+ // ndpDNSSearchListLifetimeOffset is the start of the 4-byte
+ // Lifetime field within an NDPDNSSearchList.
+ ndpDNSSearchListLifetimeOffset = 2
+
+ // ndpDNSSearchListDomainNamesOffset is the start of the DNS search list
+ // domain names within an NDPDNSSearchList.
+ ndpDNSSearchListDomainNamesOffset = 6
+
+ // minNDPDNSSearchListBodySize is the minimum NDP DNS Search List option's
+ // body size when it contains at least one domain name, as per RFC 8106
+ // section 5.3.1.
+ minNDPDNSSearchListBodySize = 14
+
+ // maxDomainNameLabelLength is the maximum length of a domain name
+ // label, as per RFC 1035 section 3.1.
+ maxDomainNameLabelLength = 63
+
+ // maxDomainNameLength is the maximum length of a domain name, including
+ // label AND label length octet, as per RFC 1035 section 3.1.
+ maxDomainNameLength = 255
+
// lengthByteUnits is the multiplier factor for the Length field of an
// NDP option. That is, the length field for NDP options is in units of
// 8 octets, as per RFC 4861 section 4.6.
@@ -224,6 +249,14 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
return opt, false, nil
+ case NDPDNSSearchListOptionType:
+ opt := NDPDNSSearchList(body)
+ if err := opt.checkDomainNames(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
default:
// We do not yet recognize the option, just skip for
// now. This is okay because RFC 4861 allows us to
@@ -684,3 +717,183 @@ func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error {
return nil
}
+
+// NDPDNSSearchList is the NDP DNS Search List option, as defined by
+// RFC 8106 section 5.2.
+type NDPDNSSearchList []byte
+
+// Type implements NDPOption.Type.
+func (o NDPDNSSearchList) Type() NDPOptionIdentifier {
+ return NDPDNSSearchListOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPDNSSearchList) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPDNSSearchList) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the reserved bytes that are before the Lifetime field.
+ for i := 0; i < ndpDNSSearchListLifetimeOffset; i++ {
+ b[i] = 0
+ }
+
+ return used
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPDNSSearchList) String() string {
+ lt := o.Lifetime()
+ domainNames, err := o.DomainNames()
+ if err != nil {
+ return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err)
+ }
+ return fmt.Sprintf("%T(%s valid for %s)", o, domainNames, lt)
+}
+
+// Lifetime returns the length of time that the DNS search list of domain names
+// in this option may be used for name resolution.
+//
+// Note, a value of 0 implies the domain names should no longer be used,
+// and a value of infinity/forever is represented by NDPInfiniteLifetime.
+func (o NDPDNSSearchList) Lifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 8106 section 5.1.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpDNSSearchListLifetimeOffset:]))
+}
+
+// DomainNames returns a DNS search list of domain names.
+//
+// DomainNames will parse the backing buffer as outlined by RFC 1035 section
+// 3.1 and return a list of strings, with all domain names in lower case.
+func (o NDPDNSSearchList) DomainNames() ([]string, error) {
+ var domainNames []string
+ return domainNames, o.iterDomainNames(func(domainName string) { domainNames = append(domainNames, domainName) })
+}
+
+// checkDomainNames iterates over the domain names in an NDP DNS Search List
+// option and returns any error it encounters.
+func (o NDPDNSSearchList) checkDomainNames() error {
+ return o.iterDomainNames(nil)
+}
+
+// iterDomainNames iterates over the domain names in an NDP DNS Search List
+// option and calls a function with each valid domain name.
+func (o NDPDNSSearchList) iterDomainNames(fn func(string)) error {
+ if l := len(o); l < minNDPDNSSearchListBodySize {
+ return fmt.Errorf("got %d bytes for NDP DNS Search List option's body, expected at least %d bytes: %w", l, minNDPDNSSearchListBodySize, io.ErrUnexpectedEOF)
+ }
+
+ var searchList bytes.Reader
+ searchList.Reset(o[ndpDNSSearchListDomainNamesOffset:])
+
+ var scratch [maxDomainNameLength]byte
+ domainName := bytes.NewBuffer(scratch[:])
+
+ // Parse the domain names, as per RFC 1035 section 3.1.
+ for searchList.Len() != 0 {
+ domainName.Reset()
+
+ // Parse a label within a domain name, as per RFC 1035 section 3.1.
+ for {
+ // The first byte is the label length.
+ labelLenByte, err := searchList.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading a label's length: %s", err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected
+ // once we start parsing a domain name; we expect the buffer to contain
+ // enough bytes for the whole domain name.
+ return fmt.Errorf("unexpected exhausted buffer while parsing a new label for a domain from NDP Search List option: %w", io.ErrUnexpectedEOF)
+ }
+ labelLen := int(labelLenByte)
+
+ // A zero-length label implies the end of a domain name.
+ if labelLen == 0 {
+ // If the domain name is empty or we have no callback function, do
+ // nothing further with the current domain name.
+ if domainName.Len() == 0 || fn == nil {
+ break
+ }
+
+ // Ignore the trailing period in the parsed domain name.
+ domainName.Truncate(domainName.Len() - 1)
+ fn(domainName.String())
+ break
+ }
+
+ // The label's length must not exceed the maximum length for a label.
+ if labelLen > maxDomainNameLabelLength {
+ return fmt.Errorf("label length of %d bytes is greater than the max label length of %d bytes for an NDP Search List option: %w", labelLen, maxDomainNameLabelLength, ErrNDPOptMalformedBody)
+ }
+
+ // The label (and trailing period) must not make the domain name too long.
+ if labelLen+1 > domainName.Cap()-domainName.Len() {
+ return fmt.Errorf("label would make an NDP Search List option's domain name longer than the max domain name length of %d bytes: %w", maxDomainNameLength, ErrNDPOptMalformedBody)
+ }
+
+ // Copy the label and add a trailing period.
+ for i := 0; i < labelLen; i++ {
+ b, err := searchList.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ panic(fmt.Sprintf("unexpected error when reading domain name's label: %s", err))
+ }
+
+ return fmt.Errorf("read %d out of %d bytes for a domain name's label from NDP Search List option: %w", i, labelLen, io.ErrUnexpectedEOF)
+ }
+
+ // As per RFC 1035 section 2.3.1:
+ // 1) the label must only contain ASCII include letters, digits and
+ // hyphens
+ // 2) the first character in a label must be a letter
+ // 3) the last letter in a label must be a letter or digit
+
+ if !isLetter(b) {
+ if i == 0 {
+ return fmt.Errorf("first character of a domain name's label in an NDP Search List option must be a letter, got character code = %d: %w", b, ErrNDPOptMalformedBody)
+ }
+
+ if b == '-' {
+ if i == labelLen-1 {
+ return fmt.Errorf("last character of a domain name's label in an NDP Search List option must not be a hyphen (-): %w", ErrNDPOptMalformedBody)
+ }
+ } else if !isDigit(b) {
+ return fmt.Errorf("domain name's label in an NDP Search List option may only contain letters, digits and hyphens, got character code = %d: %w", b, ErrNDPOptMalformedBody)
+ }
+ }
+
+ // If b is an upper case character, make it lower case.
+ if isUpperLetter(b) {
+ b = b - 'A' + 'a'
+ }
+
+ if err := domainName.WriteByte(b); err != nil {
+ panic(fmt.Sprintf("unexpected error writing label to domain name buffer: %s", err))
+ }
+ }
+ if err := domainName.WriteByte('.'); err != nil {
+ panic(fmt.Sprintf("unexpected error writing trailing period to domain name buffer: %s", err))
+ }
+ }
+ }
+
+ return nil
+}
+
+func isLetter(b byte) bool {
+ return b >= 'a' && b <= 'z' || isUpperLetter(b)
+}
+
+func isUpperLetter(b byte) bool {
+ return b >= 'A' && b <= 'Z'
+}
+
+func isDigit(b byte) bool {
+ return b >= '0' && b <= '9'
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 2341329c4..dc4591253 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -17,7 +17,9 @@ package header
import (
"bytes"
"errors"
+ "fmt"
"io"
+ "regexp"
"testing"
"time"
@@ -667,6 +669,477 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
}
}
+// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList.
+func TestNDPDNSSearchListOption(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ lifetime time.Duration
+ domainNames []string
+ err error
+ }{
+ {
+ name: "Valid1Label",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: []string{
+ "abc",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid2Label",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 5,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 0,
+ 0, 0, 0, 0, 0, 0,
+ },
+ lifetime: 5 * time.Second,
+ domainNames: []string{
+ "abc.abcd",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid3Label",
+ buf: []byte{
+ 0, 0,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ },
+ lifetime: 16777216 * time.Second,
+ domainNames: []string{
+ "abc.abcd.e",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid2Domains",
+ buf: []byte{
+ 0, 0,
+ 1, 2, 3, 4,
+ 3, 'a', 'b', 'c',
+ 0,
+ 2, 'd', 'e',
+ 3, 'x', 'y', 'z',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: 16909060 * time.Second,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid3DomainsMixedCase",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 'B', 'c',
+ 0,
+ 2, 'd', 'E',
+ 3, 'X', 'y', 'z',
+ 0,
+ 1, 'J',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ "j",
+ },
+ err: nil,
+ },
+ {
+ name: "ValidDomainAfterNULL",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 'B', 'c',
+ 0, 0, 0, 0,
+ 2, 'd', 'E',
+ 3, 'X', 'y', 'z',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid0Domains",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 0,
+ 0, 0, 0, 0, 0, 0, 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: nil,
+ },
+ {
+ name: "NoTrailingNull",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "IncorrectLength",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "IncorrectLengthWithNULL",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 7, 'a', 'b', 'c', 'd', 'e', 'f',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "LabelOfLength63",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk",
+ },
+ err: nil,
+ },
+ {
+ name: "LabelOfLength64",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DomainNameOfLength255",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij",
+ },
+ err: nil,
+ },
+ {
+ name: "DomainNameOfLength256",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "StartingDigitForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, '9', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "StartingHyphenForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, '-', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "EndingHyphenForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', '-',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "EndingDigitForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', '9',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: []string{
+ "ab9",
+ },
+ err: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opt := NDPDNSSearchList(test.buf)
+
+ if got := opt.Lifetime(); got != test.lifetime {
+ t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+ }
+ domainNames, err := opt.DomainNames()
+ if !errors.Is(err, test.err) {
+ t.Errorf("opt.DomainNames() = %s", err)
+ }
+ if diff := cmp.Diff(domainNames, test.domainNames); diff != "" {
+ t.Errorf("mismatched domain names (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) {
+ for r := rune(0); r <= 255; r++ {
+ t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) {
+ buf := []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 0 /* will be replaced */, 'c',
+ 0,
+ 0, 0, 0,
+ }
+ buf[8] = uint8(r)
+ opt := NDPDNSSearchList(buf)
+
+ // As per RFC 1035 section 2.3.1, the label must only include ASCII
+ // letters, digits and hyphens (a-z, A-Z, 0-9, -).
+ var expectedErr error
+ re := regexp.MustCompile(`[a-zA-Z0-9-]`)
+ if !re.Match([]byte{byte(r)}) {
+ expectedErr = ErrNDPOptMalformedBody
+ }
+
+ if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) {
+ t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody)
+ }
+ })
+ }
+}
+
+func TestNDPDNSSearchListOptionSerialize(t *testing.T) {
+ b := []byte{
+ 9, 8,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ }
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ expected := []byte{
+ 31, 3, 0, 0,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ }
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPDNSSearchList(b),
+ }
+ if got, want := opts.Serialize(serializer), len(expected); got != want {
+ t.Errorf("got Serialize = %d, want = %d", got, want)
+ }
+ if !bytes.Equal(targetBuf, expected) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPDNSSearchListOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType)
+ }
+
+ opt, ok := next.(NDPDNSSearchList)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next)
+ }
+ if got := opt.Type(); got != 31 {
+ t.Errorf("got Type = %d, want = 31", got)
+ }
+ if got := opt.Length(); got != 22 {
+ t.Errorf("got Length = %d, want = 22", got)
+ }
+ if got, want := opt.Lifetime(), 16777216*time.Second; got != want {
+ t.Errorf("got Lifetime = %s, want = %s", got, want)
+ }
+ domainNames, err := opt.DomainNames()
+ if err != nil {
+ t.Errorf("opt.DomainNames() = %s", err)
+ }
+ if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" {
+ t.Errorf("domain names mismatch (-want +got):\n%s", diff)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions
// the iterator was returned for is malformed.
func TestNDPOptionsIterCheck(t *testing.T) {
@@ -832,6 +1305,107 @@ func TestNDPOptionsIterCheck(t *testing.T) {
},
expectedErr: ErrNDPOptMalformedBody,
},
+ {
+ name: "DNSSearchListLargeCompliantRFC1035",
+ buf: []byte{
+ 31, 33, 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j',
+ 0,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "DNSSearchListNonCompliantRFC1035",
+ buf: []byte{
+ 31, 33, 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DNSSearchListValidSmall",
+ buf: []byte{
+ 31, 2, 0, 0,
+ 0, 0, 0, 0,
+ 6, 'a', 'b', 'c', 'd', 'e', 'f',
+ 0,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "DNSSearchListTooSmall",
+ buf: []byte{
+ 31, 1, 0, 0,
+ 0, 0, 0,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
}
for _, test := range tests {
diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go
index b40a3c212..d3bea7de4 100644
--- a/pkg/tcpip/seqnum/seqnum.go
+++ b/pkg/tcpip/seqnum/seqnum.go
@@ -46,11 +46,6 @@ func (v Value) InWindow(first Value, size Size) bool {
return v.InRange(first, first.Add(size))
}
-// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
-func Overlap(a Value, b Size, x Value, y Size) bool {
- return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
-}
-
// Add calculates the sequence number following the [v, v+s) window.
func (v Value) Add(s Size) Value {
return v + Value(s)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 8140c6dd4..193a9dfde 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -241,6 +241,16 @@ type NDPDispatcher interface {
// call functions on the stack itself.
OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+ // OnDNSSearchListOption will be called when an NDP option with a DNS
+ // search list has been received.
+ //
+ // It is up to the caller to use the domain names in the search list
+ // for only their valid lifetime. OnDNSSearchListOption may be called
+ // with new or already known domain names. If called with known domain
+ // names, their valid lifetimes must be refreshed to lifetime (it may
+ // be increased, decreased or completely invalidated when lifetime = 0.
+ OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
+
// OnDHCPv6Configuration will be called with an updated configuration that is
// available via DHCPv6 for a specified NIC.
//
@@ -714,6 +724,14 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
addrs, _ := opt.Addresses()
ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime())
+ case header.NDPDNSSearchList:
+ if ndp.nic.stack.ndpDisp == nil {
+ continue
+ }
+
+ domainNames, _ := opt.DomainNames()
+ ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime())
+
case header.NDPPrefixInformation:
prefix := opt.Subnet()
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6562a2d22..6dd460984 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -133,6 +133,12 @@ type ndpRDNSSEvent struct {
rdnss ndpRDNSS
}
+type ndpDNSSLEvent struct {
+ nicID tcpip.NICID
+ domainNames []string
+ lifetime time.Duration
+}
+
type ndpDHCPv6Event struct {
nicID tcpip.NICID
configuration stack.DHCPv6ConfigurationFromNDPRA
@@ -150,6 +156,8 @@ type ndpDispatcher struct {
rememberPrefix bool
autoGenAddrC chan ndpAutoGenAddrEvent
rdnssC chan ndpRDNSSEvent
+ dnsslC chan ndpDNSSLEvent
+ routeTable []tcpip.Route
dhcpv6ConfigurationC chan ndpDHCPv6Event
}
@@ -257,6 +265,17 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc
}
}
+// Implements stack.NDPDispatcher.OnDNSSearchListOption.
+func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) {
+ if n.dnsslC != nil {
+ n.dnsslC <- ndpDNSSLEvent{
+ nicID,
+ domainNames,
+ lifetime,
+ }
+ }
+}
+
// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
if c := n.dhcpv6ConfigurationC; c != nil {
@@ -3386,6 +3405,112 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
}
+// TestNDPDNSSearchListDispatch tests that the integrator is informed when an
+// NDP DNS Search List option is received with at least one domain name in the
+// search list.
+func TestNDPDNSSearchListDispatch(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dnsslC: make(chan ndpDNSSLEvent, 3),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ optSer := header.NDPOptionsSerializer{
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 2, 'h', 'i',
+ 0,
+ }),
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 1, 'i',
+ 0,
+ 2, 'a', 'm',
+ 2, 'm', 'e',
+ 0,
+ }),
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 1, 0,
+ 3, 'x', 'y', 'z',
+ 0,
+ 5, 'h', 'e', 'l', 'l', 'o',
+ 5, 'w', 'o', 'r', 'l', 'd',
+ 0,
+ 4, 't', 'h', 'i', 's',
+ 2, 'i', 's',
+ 1, 'a',
+ 4, 't', 'e', 's', 't',
+ 0,
+ }),
+ }
+ expected := []struct {
+ domainNames []string
+ lifetime time.Duration
+ }{
+ {
+ domainNames: []string{
+ "hi",
+ },
+ lifetime: 0,
+ },
+ {
+ domainNames: []string{
+ "i",
+ "am.me",
+ },
+ lifetime: time.Second,
+ },
+ {
+ domainNames: []string{
+ "xyz",
+ "hello.world",
+ "this.is.a.test",
+ },
+ lifetime: 256 * time.Second,
+ },
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
+
+ for i, expected := range expected {
+ select {
+ case dnssl := <-ndpDisp.dnsslC:
+ if dnssl.nicID != nicID {
+ t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID)
+ }
+ if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" {
+ t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff)
+ }
+ if dnssl.lifetime != expected.lifetime {
+ t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime)
+ }
+ default:
+ t.Fatal("expected a DNSSL event")
+ }
+ }
+
+ // Should have no more DNSSL options.
+ select {
+ case <-ndpDisp.dnsslC:
+ t.Fatal("unexpectedly got a DNSSL event")
+ default:
+ }
+}
+
// TestCleanupNDPState tests that all discovered routers and prefixes, and
// auto-generated addresses are invalidated when a NIC becomes a router.
func TestCleanupNDPState(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index edb7718a6..61426623c 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -109,3 +109,13 @@ go_test(
"//runsc/testutil",
],
)
+
+go_test(
+ name = "rcv_test",
+ size = "small",
+ srcs = ["rcv_test.go"],
+ deps = [
+ ":tcp",
+ "//pkg/tcpip/seqnum",
+ ],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 5bb243e3b..e6a23c978 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -101,7 +101,7 @@ type listenContext struct {
// v6Only is true if listenEP is a dual stack socket and has the
// IPV6_V6ONLY option set.
- v6only bool
+ v6Only bool
// netProto indicates the network protocol(IPv4/v6) for the listening
// endpoint.
@@ -126,12 +126,12 @@ func timeStamp() uint32 {
}
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stk,
rcvWnd: rcvWnd,
hasher: sha1.New(),
- v6only: v6only,
+ v6Only: v6Only,
netProto: netProto,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
@@ -207,7 +207,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
netProto = s.route.NetProto
}
n := newEndpoint(l.stack, netProto, queue)
- n.v6only = l.v6only
+ n.v6only = l.v6Only
n.ID = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
@@ -293,7 +293,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
}
// Perform the 3-way handshake.
- h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
+ h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.mu.Unlock()
ep.Close()
@@ -613,8 +613,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
- v6only := e.v6only
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
+ v6Only := e.v6only
+ ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 368865911..76e27bf26 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -105,24 +105,11 @@ type handshake struct {
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
- rcvWndScale := ep.rcvWndScaleForHandshake()
-
- // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
- // window offered in SYN won't be reduced due to the loss of precision if
- // window scaling is enabled after the handshake.
- rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
-
- // Ensure we can always accept at least 1 byte if the scale specified
- // was too high for the provided rcvWnd.
- if rcvWnd == 0 {
- rcvWnd = 1
- }
-
h := handshake{
ep: ep,
active: true,
rcvWnd: rcvWnd,
- rcvWndScale: int(rcvWndScale),
+ rcvWndScale: ep.rcvWndScaleForHandshake(),
}
h.resetState()
return h
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7e8def82d..45f2aa78b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1062,6 +1062,19 @@ func (e *endpoint) initialReceiveWindow() int {
if rcvWnd > routeWnd {
rcvWnd = routeWnd
}
+ rcvWndScale := e.rcvWndScaleForHandshake()
+
+ // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
+ // window offered in SYN won't be reduced due to the loss of precision if
+ // window scaling is enabled after the handshake.
+ rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
+
+ // Ensure we can always accept at least 1 byte if the scale specified
+ // was too high for the provided rcvWnd.
+ if rcvWnd == 0 {
+ rcvWnd = 1
+ }
+
return rcvWnd
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index caf8977b3..a4b73b588 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -70,13 +70,24 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
// acceptable checks if the segment sequence number range is acceptable
// according to the table on page 26 of RFC 793.
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- rcvWnd := r.rcvNxt.Size(r.rcvAcc)
- if rcvWnd == 0 {
- return segLen == 0 && segSeq == r.rcvNxt
- }
+ return Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc)
+}
- return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
- seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
+// Acceptable checks if a segment that starts at segSeq and has length segLen is
+// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
+// before rcvAcc, according to the table on page 26 and 69 of RFC 793.
+func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
+ if rcvNxt == rcvAcc {
+ return segLen == 0 && segSeq == rcvNxt
+ }
+ if segLen == 0 {
+ // rcvWnd is incremented by 1 because that is Linux's behavior despite the
+ // RFC.
+ return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
+ }
+ // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming
+ // the payload, so we'll accept any payload that overlaps the receieve window.
+ return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc)
}
// getSendParams returns the parameters needed by the sender when building
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
new file mode 100644
index 000000000..dc02729ce
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rcv_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+func TestAcceptable(t *testing.T) {
+ for _, tt := range []struct {
+ segSeq seqnum.Value
+ segLen seqnum.Size
+ rcvNxt, rcvAcc seqnum.Value
+ want bool
+ }{
+ // The segment is smaller than the window.
+ {105, 2, 100, 104, false},
+ {105, 2, 101, 105, false},
+ {105, 2, 102, 106, true},
+ {105, 2, 103, 107, true},
+ {105, 2, 104, 108, true},
+ {105, 2, 105, 109, true},
+ {105, 2, 106, 110, true},
+ {105, 2, 107, 111, false},
+
+ // The segment is larger than the window.
+ {105, 4, 103, 105, false},
+ {105, 4, 104, 106, true},
+ {105, 4, 105, 107, true},
+ {105, 4, 106, 108, true},
+ {105, 4, 107, 109, true},
+ {105, 4, 108, 110, true},
+ {105, 4, 109, 111, false},
+ {105, 4, 110, 112, false},
+
+ // The segment has no width.
+ {105, 0, 100, 102, false},
+ {105, 0, 101, 103, false},
+ {105, 0, 102, 104, false},
+ {105, 0, 103, 105, true},
+ {105, 0, 104, 106, true},
+ {105, 0, 105, 107, true},
+ {105, 0, 106, 108, false},
+ {105, 0, 107, 109, false},
+
+ // The receive window has no width.
+ {105, 2, 103, 103, false},
+ {105, 2, 104, 104, false},
+ {105, 2, 105, 105, false},
+ {105, 2, 106, 106, false},
+ {105, 2, 107, 107, false},
+ {105, 2, 108, 108, false},
+ {105, 2, 109, 109, false},
+ } {
+ if got := tcp.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
+ t.Errorf("tcp.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 3ad6994a7..2025ff757 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
+ "//pkg/tcpip/transport/tcp",
],
)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 93712cd45..30d05200f 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -20,6 +20,7 @@ package tcpconntrack
import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)
// Result is returned when the state of a TCB is updated in response to an
@@ -311,17 +312,7 @@ type stream struct {
// the window is zero, if it's a packet with no payload and sequence number
// equal to una.
func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- wnd := s.una.Size(s.end)
- if wnd == 0 {
- return segLen == 0 && segSeq == s.una
- }
-
- // Make sure [segSeq, seqSeq+segLen) is non-empty.
- if segLen == 0 {
- segLen = 1
- }
-
- return seqnum.Overlap(s.una, wnd, segSeq, segLen)
+ return tcp.Acceptable(segSeq, segLen, s.una, s.end)
}
// closed determines if the stream has already been closed. This happens when