summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2021-10-21 13:50:08 -0700
committergVisor bot <gvisor-bot@google.com>2021-10-21 13:53:04 -0700
commit207221ffb27f2010c46468d827f1817432df3960 (patch)
treef1c22aa2c0fcc729ce56bdfb5eb592175cb4f621
parentcfcd3eba9f011b73be1f359f6da7af7f2584a089 (diff)
Add an integration test for istio like redirect.
Updates #6441,#6317 PiperOrigin-RevId: 404872327
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go24
-rw-r--r--pkg/tcpip/header/ipv4.go9
-rw-r--r--pkg/tcpip/tests/integration/BUILD22
-rw-r--r--pkg/tcpip/tests/integration/istio_test.go365
4 files changed, 415 insertions, 5 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 010e2e833..1f2bcaf65 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -19,6 +19,7 @@ import (
"bytes"
"context"
"errors"
+ "fmt"
"io"
"net"
"time"
@@ -471,9 +472,9 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc
return DialContextTCP(context.Background(), s, addr, network)
}
-// DialContextTCP creates a new TCPConn connected to the specified address
-// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
+// DialTCPWithBind creates a new TCPConn connected to the specified
+// remoteAddress with its local address bound to localAddr.
+func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -494,7 +495,14 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
default:
}
- err = ep.Connect(addr)
+ // Bind before connect if requested.
+ if localAddr != (tcpip.FullAddress{}) {
+ if err = ep.Bind(localAddr); err != nil {
+ return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
+ }
+ }
+
+ err = ep.Connect(remoteAddr)
if _, ok := err.(*tcpip.ErrConnectStarted); ok {
select {
case <-ctx.Done():
@@ -510,7 +518,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
return nil, &net.OpError{
Op: "connect",
Net: "tcp",
- Addr: fullToTCPAddr(addr),
+ Addr: fullToTCPAddr(remoteAddr),
Err: errors.New(err.String()),
}
}
@@ -518,6 +526,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
return NewTCPConn(&wq, ep), nil
}
+// DialContextTCP creates a new TCPConn connected to the specified address
+// with the option of adding cancellation and timeouts.
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
+ return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network)
+}
+
// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
// net.Conn and net.PacketConn.
type UDPConn struct {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index dcc549c7b..7baaf0d17 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -208,6 +208,15 @@ var IPv4EmptySubnet = func() tcpip.Subnet {
return subnet
}()
+// IPv4LoopbackSubnet is the loopback subnet for IPv4.
+var IPv4LoopbackSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\x7f\x00\x00\x00"), tcpip.AddressMask("\xff\x00\x00\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPVersion returns the version of IP used in the given packet. It returns -1
// if the packet is not large enough to contain the version field.
func IPVersion(b []byte) int {
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 7c998eaae..99f4d4d0e 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -143,3 +143,25 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "istio_test",
+ size = "small",
+ srcs = ["istio_test.go"],
+ deps = [
+ "//pkg/context",
+ "//pkg/rand",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/pipe",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/tests/integration/istio_test.go b/pkg/tcpip/tests/integration/istio_test.go
new file mode 100644
index 000000000..95d994ef8
--- /dev/null
+++ b/pkg/tcpip/tests/integration/istio_test.go
@@ -0,0 +1,365 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package istio_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strconv"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+// testContext encapsulates the state required to run tests that simulate
+// an istio like environment.
+//
+// A diagram depicting the setup is shown below.
+// +-----------------------------------------------------------------------+
+// | +-------------------------------------------------+ |
+// | + ----------+ | + -----------------+ PROXY +----------+ | |
+// | | clientEP | | | serverListeningEP|--accepted-> | serverEP |-+ | |
+// | + ----------+ | + -----------------+ +----------+ | | |
+// | | -------|-------------+ +----------+ | | |
+// | | | | | proxyEP |-+ | |
+// | +-----redirect | +----------+ | |
+// | + ------------+---|------+---+ |
+// | | |
+// | Local Stack. | |
+// +-------------------------------------------------------|---------------+
+// |
+// +-----------------------------------------------------------------------+
+// | remoteStack | |
+// | +-------------SYN ---------------| |
+// | | | |
+// | +-------------------|--------------------------------|-_---+ |
+// | | + -----------------+ + ----------+ | | |
+// | | | remoteListeningEP|--accepted--->| remoteEP |<++ | |
+// | | + -----------------+ + ----------+ | |
+// | | Remote HTTP Server | |
+// | +----------------------------------------------------------+ |
+// +-----------------------------------------------------------------------+
+//
+type testContext struct {
+ // localServerListener is the listening port for the server which will proxy
+ // all traffic to the remote EP.
+ localServerListener *gonet.TCPListener
+
+ // remoteListenListener is the remote listening endpoint that will receive
+ // connections from server.
+ remoteServerListener *gonet.TCPListener
+
+ // localStack is the stack used to create client/server endpoints and
+ // also the stack on which we install NAT redirect rules.
+ localStack *stack.Stack
+
+ // remoteStack is the stack that represents a *remote* server.
+ remoteStack *stack.Stack
+
+ // defaultResponse is the response served by the HTTP server for all GET
+ defaultResponse []byte
+
+ // requests. wg is used to wait for HTTP server and Proxy to terminate before
+ // returning from cleanup.
+ wg sync.WaitGroup
+}
+
+func (ctx *testContext) cleanup() {
+ ctx.localServerListener.Close()
+ ctx.localStack.Close()
+ ctx.remoteServerListener.Close()
+ ctx.remoteStack.Close()
+ ctx.wg.Wait()
+}
+
+const (
+ localServerPort = 8080
+ remoteServerPort = 9090
+)
+
+var (
+ localIPv4Addr1 = testutil.MustParse4("10.0.0.1")
+ localIPv4Addr2 = testutil.MustParse4("10.0.0.2")
+ loopbackIPv4Addr = testutil.MustParse4("127.0.0.1")
+ remoteIPv4Addr1 = testutil.MustParse4("10.0.0.3")
+)
+
+func newTestContext(t *testing.T) *testContext {
+ t.Helper()
+ localNIC, remoteNIC := pipe.New("" /* linkAddr1 */, "" /* linkAddr2 */)
+
+ localStack := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ HandleLocal: true,
+ })
+
+ remoteStack := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ HandleLocal: true,
+ })
+
+ // Add loopback NIC. We need a loopback NIC as NAT redirect rule redirect to
+ // loopback address + specified port.
+ loopbackNIC := loopback.New()
+ const loopbackNICID = tcpip.NICID(1)
+ if err := localStack.CreateNIC(loopbackNICID, sniffer.New(loopbackNIC)); err != nil {
+ t.Fatalf("localStack.CreateNIC(%d, _): %s", loopbackNICID, err)
+ }
+ loopbackAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: loopbackIPv4Addr.WithPrefix(),
+ }
+ if err := localStack.AddProtocolAddress(loopbackNICID, loopbackAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", loopbackNICID, loopbackAddr, err)
+ }
+
+ // Create linked NICs that connects the local and remote stack.
+ const localNICID = tcpip.NICID(2)
+ const remoteNICID = tcpip.NICID(3)
+ if err := localStack.CreateNIC(localNICID, sniffer.New(localNIC)); err != nil {
+ t.Fatalf("localStack.CreateNIC(%d, _): %s", localNICID, err)
+ }
+ if err := remoteStack.CreateNIC(remoteNICID, sniffer.New(remoteNIC)); err != nil {
+ t.Fatalf("remoteStack.CreateNIC(%d, _): %s", remoteNICID, err)
+ }
+
+ for _, addr := range []tcpip.Address{localIPv4Addr1, localIPv4Addr2} {
+ localProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := localStack.AddProtocolAddress(localNICID, localProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", localNICID, localProtocolAddr, err)
+ }
+ }
+
+ remoteProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: remoteIPv4Addr1.WithPrefix(),
+ }
+ if err := remoteStack.AddProtocolAddress(remoteNICID, remoteProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("remoteStack.AddProtocolAddress(%d, %+v, {}): %s", remoteNICID, remoteProtocolAddr, err)
+ }
+
+ // Setup route table for local and remote stacks.
+ localStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4LoopbackSubnet,
+ NIC: loopbackNICID,
+ },
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: localNICID,
+ },
+ })
+ remoteStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: remoteNICID,
+ },
+ })
+
+ const netProto = ipv4.ProtocolNumber
+ localServerAddress := tcpip.FullAddress{
+ Port: localServerPort,
+ }
+
+ localServerListener, err := gonet.ListenTCP(localStack, localServerAddress, netProto)
+ if err != nil {
+ t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", localServerAddress, netProto, err)
+ }
+
+ remoteServerAddress := tcpip.FullAddress{
+ Port: remoteServerPort,
+ }
+ remoteServerListener, err := gonet.ListenTCP(remoteStack, remoteServerAddress, netProto)
+ if err != nil {
+ t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", remoteServerAddress, netProto, err)
+ }
+
+ // Initialize a random default response served by the HTTP server.
+ defaultResponse := make([]byte, 512<<10)
+ if _, err := rand.Read(defaultResponse); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+
+ tc := &testContext{
+ localServerListener: localServerListener,
+ remoteServerListener: remoteServerListener,
+ localStack: localStack,
+ remoteStack: remoteStack,
+ defaultResponse: defaultResponse,
+ }
+
+ tc.startServers(t)
+ return tc
+}
+
+func (ctx *testContext) startServers(t *testing.T) {
+ ctx.wg.Add(1)
+ go func() {
+ defer ctx.wg.Done()
+ ctx.startHTTPServer()
+ }()
+ ctx.wg.Add(1)
+ go func() {
+ defer ctx.wg.Done()
+ ctx.startTCPProxyServer(t)
+ }()
+}
+
+func (ctx *testContext) startTCPProxyServer(t *testing.T) {
+ t.Helper()
+ for {
+ conn, err := ctx.localServerListener.Accept()
+ if err != nil {
+ t.Logf("terminating local proxy server: %s", err)
+ return
+ }
+ // Start a goroutine to handle this inbound connection.
+ go func() {
+ remoteServerAddr := tcpip.FullAddress{
+ Addr: remoteIPv4Addr1,
+ Port: remoteServerPort,
+ }
+ localServerAddr := tcpip.FullAddress{
+ Addr: localIPv4Addr2,
+ }
+ serverConn, err := gonet.DialTCPWithBind(context.Background(), ctx.localStack, localServerAddr, remoteServerAddr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Logf("gonet.DialTCP(_, %+v, %d) = %s", remoteServerAddr, ipv4.ProtocolNumber, err)
+ return
+ }
+ proxy(conn, serverConn)
+ t.Logf("proxying completed")
+ }()
+ }
+}
+
+// proxy transparently proxies the TCP payload from conn1 to conn2
+// and vice versa.
+func proxy(conn1, conn2 net.Conn) {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ io.Copy(conn2, conn1)
+ conn1.Close()
+ conn2.Close()
+ }()
+ wg.Add(1)
+ go func() {
+ io.Copy(conn1, conn2)
+ conn1.Close()
+ conn2.Close()
+ }()
+ wg.Wait()
+}
+
+func (ctx *testContext) startHTTPServer() {
+ handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(ctx.defaultResponse))
+ })
+ s := &http.Server{
+ Handler: handlerFunc,
+ }
+ s.Serve(ctx.remoteServerListener)
+}
+
+func TestOutboundNATRedirect(t *testing.T) {
+ ctx := newTestContext(t)
+ defer ctx.cleanup()
+
+ // Install an IPTable rule to redirect all TCP traffic with the sourceIP of
+ // localIPv4Addr1 to the tcp proxy port.
+ ipt := ctx.localStack.IPTables()
+ tbl := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ ruleIdx := tbl.BuiltinChains[stack.Output]
+ tbl.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+ Protocol: tcp.ProtocolNumber,
+ CheckProtocol: true,
+ Src: localIPv4Addr1,
+ SrcMask: tcpip.Address("\xff\xff\xff\xff"),
+ }
+ tbl.Rules[ruleIdx].Target = &stack.RedirectTarget{
+ Port: localServerPort,
+ NetworkProtocol: ipv4.ProtocolNumber,
+ }
+ tbl.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, tbl, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, false): %s", stack.NATID, err)
+ }
+
+ dialFunc := func(protocol, address string) (net.Conn, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse address: %s, err: %s", address, err)
+ }
+
+ remoteServerIP := net.ParseIP(host)
+ remoteServerPort, err := strconv.Atoi(port)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse port from string %s, err: %s", port, err)
+ }
+ remoteAddress := tcpip.FullAddress{
+ Addr: tcpip.Address(remoteServerIP.To4()),
+ Port: uint16(remoteServerPort),
+ }
+
+ // Dial with an explicit source address bound so that the redirect rule will
+ // be able to correctly redirect these packets.
+ localAddr := tcpip.FullAddress{Addr: localIPv4Addr1}
+ return gonet.DialTCPWithBind(context.Background(), ctx.localStack, localAddr, remoteAddress, ipv4.ProtocolNumber)
+ }
+
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ Dial: dialFunc,
+ },
+ }
+
+ serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Addr1), remoteServerPort)
+ response, err := httpClient.Get(serverURL)
+ if err != nil {
+ t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
+ }
+ if got, want := response.StatusCode, http.StatusOK; got != want {
+ t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+ }
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+ }
+ response.Body.Close()
+ if diff := cmp.Diff(body, ctx.defaultResponse); diff != "" {
+ t.Fatalf("unexpected response (-want +got): \n %s", diff)
+ }
+}