diff options
author | Bhasker Hariharan <bhaskerh@google.com> | 2021-10-21 13:50:08 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-10-21 13:53:04 -0700 |
commit | 207221ffb27f2010c46468d827f1817432df3960 (patch) | |
tree | f1c22aa2c0fcc729ce56bdfb5eb592175cb4f621 | |
parent | cfcd3eba9f011b73be1f359f6da7af7f2584a089 (diff) |
Add an integration test for istio like redirect.
Updates #6441,#6317
PiperOrigin-RevId: 404872327
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 22 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/istio_test.go | 365 |
4 files changed, 415 insertions, 5 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 010e2e833..1f2bcaf65 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -19,6 +19,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "net" "time" @@ -471,9 +472,9 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc return DialContextTCP(context.Background(), s, addr, network) } -// DialContextTCP creates a new TCPConn connected to the specified address -// with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { +// DialTCPWithBind creates a new TCPConn connected to the specified +// remoteAddress with its local address bound to localAddr. +func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { // Create TCP endpoint, then connect. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) @@ -494,7 +495,14 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, default: } - err = ep.Connect(addr) + // Bind before connect if requested. + if localAddr != (tcpip.FullAddress{}) { + if err = ep.Bind(localAddr); err != nil { + return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err) + } + } + + err = ep.Connect(remoteAddr) if _, ok := err.(*tcpip.ErrConnectStarted); ok { select { case <-ctx.Done(): @@ -510,7 +518,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, return nil, &net.OpError{ Op: "connect", Net: "tcp", - Addr: fullToTCPAddr(addr), + Addr: fullToTCPAddr(remoteAddr), Err: errors.New(err.String()), } } @@ -518,6 +526,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, return NewTCPConn(&wq, ep), nil } +// DialContextTCP creates a new TCPConn connected to the specified address +// with the option of adding cancellation and timeouts. +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { + return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network) +} + // A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements // net.Conn and net.PacketConn. type UDPConn struct { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index dcc549c7b..7baaf0d17 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -208,6 +208,15 @@ var IPv4EmptySubnet = func() tcpip.Subnet { return subnet }() +// IPv4LoopbackSubnet is the loopback subnet for IPv4. +var IPv4LoopbackSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(tcpip.Address("\x7f\x00\x00\x00"), tcpip.AddressMask("\xff\x00\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + // IPVersion returns the version of IP used in the given packet. It returns -1 // if the packet is not large enough to contain the version field. func IPVersion(b []byte) int { diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 7c998eaae..99f4d4d0e 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -143,3 +143,25 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "istio_test", + size = "small", + srcs = ["istio_test.go"], + deps = [ + "//pkg/context", + "//pkg/rand", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/header", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/pipe", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/tests/integration/istio_test.go b/pkg/tcpip/tests/integration/istio_test.go new file mode 100644 index 000000000..95d994ef8 --- /dev/null +++ b/pkg/tcpip/tests/integration/istio_test.go @@ -0,0 +1,365 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package istio_test + +import ( + "fmt" + "io" + "net" + "net/http" + "strconv" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/pipe" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" +) + +// testContext encapsulates the state required to run tests that simulate +// an istio like environment. +// +// A diagram depicting the setup is shown below. +// +-----------------------------------------------------------------------+ +// | +-------------------------------------------------+ | +// | + ----------+ | + -----------------+ PROXY +----------+ | | +// | | clientEP | | | serverListeningEP|--accepted-> | serverEP |-+ | | +// | + ----------+ | + -----------------+ +----------+ | | | +// | | -------|-------------+ +----------+ | | | +// | | | | | proxyEP |-+ | | +// | +-----redirect | +----------+ | | +// | + ------------+---|------+---+ | +// | | | +// | Local Stack. | | +// +-------------------------------------------------------|---------------+ +// | +// +-----------------------------------------------------------------------+ +// | remoteStack | | +// | +-------------SYN ---------------| | +// | | | | +// | +-------------------|--------------------------------|-_---+ | +// | | + -----------------+ + ----------+ | | | +// | | | remoteListeningEP|--accepted--->| remoteEP |<++ | | +// | | + -----------------+ + ----------+ | | +// | | Remote HTTP Server | | +// | +----------------------------------------------------------+ | +// +-----------------------------------------------------------------------+ +// +type testContext struct { + // localServerListener is the listening port for the server which will proxy + // all traffic to the remote EP. + localServerListener *gonet.TCPListener + + // remoteListenListener is the remote listening endpoint that will receive + // connections from server. + remoteServerListener *gonet.TCPListener + + // localStack is the stack used to create client/server endpoints and + // also the stack on which we install NAT redirect rules. + localStack *stack.Stack + + // remoteStack is the stack that represents a *remote* server. + remoteStack *stack.Stack + + // defaultResponse is the response served by the HTTP server for all GET + defaultResponse []byte + + // requests. wg is used to wait for HTTP server and Proxy to terminate before + // returning from cleanup. + wg sync.WaitGroup +} + +func (ctx *testContext) cleanup() { + ctx.localServerListener.Close() + ctx.localStack.Close() + ctx.remoteServerListener.Close() + ctx.remoteStack.Close() + ctx.wg.Wait() +} + +const ( + localServerPort = 8080 + remoteServerPort = 9090 +) + +var ( + localIPv4Addr1 = testutil.MustParse4("10.0.0.1") + localIPv4Addr2 = testutil.MustParse4("10.0.0.2") + loopbackIPv4Addr = testutil.MustParse4("127.0.0.1") + remoteIPv4Addr1 = testutil.MustParse4("10.0.0.3") +) + +func newTestContext(t *testing.T) *testContext { + t.Helper() + localNIC, remoteNIC := pipe.New("" /* linkAddr1 */, "" /* linkAddr2 */) + + localStack := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: true, + }) + + remoteStack := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + HandleLocal: true, + }) + + // Add loopback NIC. We need a loopback NIC as NAT redirect rule redirect to + // loopback address + specified port. + loopbackNIC := loopback.New() + const loopbackNICID = tcpip.NICID(1) + if err := localStack.CreateNIC(loopbackNICID, sniffer.New(loopbackNIC)); err != nil { + t.Fatalf("localStack.CreateNIC(%d, _): %s", loopbackNICID, err) + } + loopbackAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: loopbackIPv4Addr.WithPrefix(), + } + if err := localStack.AddProtocolAddress(loopbackNICID, loopbackAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", loopbackNICID, loopbackAddr, err) + } + + // Create linked NICs that connects the local and remote stack. + const localNICID = tcpip.NICID(2) + const remoteNICID = tcpip.NICID(3) + if err := localStack.CreateNIC(localNICID, sniffer.New(localNIC)); err != nil { + t.Fatalf("localStack.CreateNIC(%d, _): %s", localNICID, err) + } + if err := remoteStack.CreateNIC(remoteNICID, sniffer.New(remoteNIC)); err != nil { + t.Fatalf("remoteStack.CreateNIC(%d, _): %s", remoteNICID, err) + } + + for _, addr := range []tcpip.Address{localIPv4Addr1, localIPv4Addr2} { + localProtocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := localStack.AddProtocolAddress(localNICID, localProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", localNICID, localProtocolAddr, err) + } + } + + remoteProtocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: remoteIPv4Addr1.WithPrefix(), + } + if err := remoteStack.AddProtocolAddress(remoteNICID, remoteProtocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("remoteStack.AddProtocolAddress(%d, %+v, {}): %s", remoteNICID, remoteProtocolAddr, err) + } + + // Setup route table for local and remote stacks. + localStack.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4LoopbackSubnet, + NIC: loopbackNICID, + }, + { + Destination: header.IPv4EmptySubnet, + NIC: localNICID, + }, + }) + remoteStack.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: remoteNICID, + }, + }) + + const netProto = ipv4.ProtocolNumber + localServerAddress := tcpip.FullAddress{ + Port: localServerPort, + } + + localServerListener, err := gonet.ListenTCP(localStack, localServerAddress, netProto) + if err != nil { + t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", localServerAddress, netProto, err) + } + + remoteServerAddress := tcpip.FullAddress{ + Port: remoteServerPort, + } + remoteServerListener, err := gonet.ListenTCP(remoteStack, remoteServerAddress, netProto) + if err != nil { + t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", remoteServerAddress, netProto, err) + } + + // Initialize a random default response served by the HTTP server. + defaultResponse := make([]byte, 512<<10) + if _, err := rand.Read(defaultResponse); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + + tc := &testContext{ + localServerListener: localServerListener, + remoteServerListener: remoteServerListener, + localStack: localStack, + remoteStack: remoteStack, + defaultResponse: defaultResponse, + } + + tc.startServers(t) + return tc +} + +func (ctx *testContext) startServers(t *testing.T) { + ctx.wg.Add(1) + go func() { + defer ctx.wg.Done() + ctx.startHTTPServer() + }() + ctx.wg.Add(1) + go func() { + defer ctx.wg.Done() + ctx.startTCPProxyServer(t) + }() +} + +func (ctx *testContext) startTCPProxyServer(t *testing.T) { + t.Helper() + for { + conn, err := ctx.localServerListener.Accept() + if err != nil { + t.Logf("terminating local proxy server: %s", err) + return + } + // Start a goroutine to handle this inbound connection. + go func() { + remoteServerAddr := tcpip.FullAddress{ + Addr: remoteIPv4Addr1, + Port: remoteServerPort, + } + localServerAddr := tcpip.FullAddress{ + Addr: localIPv4Addr2, + } + serverConn, err := gonet.DialTCPWithBind(context.Background(), ctx.localStack, localServerAddr, remoteServerAddr, ipv4.ProtocolNumber) + if err != nil { + t.Logf("gonet.DialTCP(_, %+v, %d) = %s", remoteServerAddr, ipv4.ProtocolNumber, err) + return + } + proxy(conn, serverConn) + t.Logf("proxying completed") + }() + } +} + +// proxy transparently proxies the TCP payload from conn1 to conn2 +// and vice versa. +func proxy(conn1, conn2 net.Conn) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + io.Copy(conn2, conn1) + conn1.Close() + conn2.Close() + }() + wg.Add(1) + go func() { + io.Copy(conn1, conn2) + conn1.Close() + conn2.Close() + }() + wg.Wait() +} + +func (ctx *testContext) startHTTPServer() { + handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(ctx.defaultResponse)) + }) + s := &http.Server{ + Handler: handlerFunc, + } + s.Serve(ctx.remoteServerListener) +} + +func TestOutboundNATRedirect(t *testing.T) { + ctx := newTestContext(t) + defer ctx.cleanup() + + // Install an IPTable rule to redirect all TCP traffic with the sourceIP of + // localIPv4Addr1 to the tcp proxy port. + ipt := ctx.localStack.IPTables() + tbl := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := tbl.BuiltinChains[stack.Output] + tbl.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + Protocol: tcp.ProtocolNumber, + CheckProtocol: true, + Src: localIPv4Addr1, + SrcMask: tcpip.Address("\xff\xff\xff\xff"), + } + tbl.Rules[ruleIdx].Target = &stack.RedirectTarget{ + Port: localServerPort, + NetworkProtocol: ipv4.ProtocolNumber, + } + tbl.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, tbl, false /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, false): %s", stack.NATID, err) + } + + dialFunc := func(protocol, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("unable to parse address: %s, err: %s", address, err) + } + + remoteServerIP := net.ParseIP(host) + remoteServerPort, err := strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("unable to parse port from string %s, err: %s", port, err) + } + remoteAddress := tcpip.FullAddress{ + Addr: tcpip.Address(remoteServerIP.To4()), + Port: uint16(remoteServerPort), + } + + // Dial with an explicit source address bound so that the redirect rule will + // be able to correctly redirect these packets. + localAddr := tcpip.FullAddress{Addr: localIPv4Addr1} + return gonet.DialTCPWithBind(context.Background(), ctx.localStack, localAddr, remoteAddress, ipv4.ProtocolNumber) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + Dial: dialFunc, + }, + } + + serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Addr1), remoteServerPort) + response, err := httpClient.Get(serverURL) + if err != nil { + t.Fatalf("httpClient.Get(\"/\") failed: %s", err) + } + if got, want := response.StatusCode, http.StatusOK; got != want { + t.Fatalf("unexpected status code got: %d, want: %d", got, want) + } + body, err := io.ReadAll(response.Body) + if err != nil { + t.Fatalf("io.ReadAll(response.Body) failed: %s", err) + } + response.Body.Close() + if diff := cmp.Diff(body, ctx.defaultResponse); diff != "" { + t.Fatalf("unexpected response (-want +got): \n %s", diff) + } +} |