diff options
-rw-r--r-- | pkg/test/testutil/testutil.go | 5 | ||||
-rw-r--r-- | test/iptables/filter_input.go | 253 | ||||
-rw-r--r-- | test/iptables/filter_output.go | 252 | ||||
-rw-r--r-- | test/iptables/iptables.go | 61 | ||||
-rw-r--r-- | test/iptables/iptables_test.go | 68 | ||||
-rw-r--r-- | test/iptables/iptables_util.go | 85 | ||||
-rw-r--r-- | test/iptables/nat.go | 225 | ||||
-rw-r--r-- | test/iptables/runner/main.go | 5 |
8 files changed, 571 insertions, 383 deletions
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index 64c292698..1580527b5 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -316,6 +316,11 @@ func Copy(src, dst string) error { func Poll(cb func() error, timeout time.Duration) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() + return PollContext(ctx, cb) +} + +// PollContext is like Poll, but takes a context instead of a timeout. +func PollContext(ctx context.Context, cb func() error) error { b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx) return backoff.Retry(cb, b) } diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 5737ee317..b45d448b8 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -15,6 +15,7 @@ package iptables import ( + "context" "errors" "fmt" "net" @@ -53,7 +54,7 @@ func init() { } // FilterInputDropUDP tests that we can drop UDP traffic. -type FilterInputDropUDP struct{} +type FilterInputDropUDP struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropUDP) Name() string { @@ -61,15 +62,17 @@ func (FilterInputDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -79,12 +82,12 @@ func (FilterInputDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic. -type FilterInputDropOnlyUDP struct{} +type FilterInputDropOnlyUDP struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropOnlyUDP) Name() string { @@ -92,13 +95,13 @@ func (FilterInputDropOnlyUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropOnlyUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { return err } // Listen for a TCP connection, which should be allowed. - if err := listenTCP(acceptPort, sendloopDuration); err != nil { + if err := listenTCP(ctx, acceptPort); err != nil { return fmt.Errorf("failed to establish a connection %v", err) } @@ -106,14 +109,14 @@ func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropOnlyUDP) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropOnlyUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Try to establish a TCP connection with the container, which should // succeed. - return connectTCP(ip, acceptPort, sendloopDuration) + return connectTCP(ctx, ip, acceptPort) } // FilterInputDropUDPPort tests that we can drop UDP traffic by port. -type FilterInputDropUDPPort struct{} +type FilterInputDropUDPPort struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropUDPPort) Name() string { @@ -121,15 +124,17 @@ func (FilterInputDropUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -139,13 +144,13 @@ func (FilterInputDropUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDPPort) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port // doesn't drop packets on other ports. -type FilterInputDropDifferentUDPPort struct{} +type FilterInputDropDifferentUDPPort struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropDifferentUDPPort) Name() string { @@ -153,13 +158,13 @@ func (FilterInputDropDifferentUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropDifferentUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for UDP packets on another port. - if err := listenUDP(acceptPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, acceptPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) } @@ -167,12 +172,12 @@ func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP, ipv6 bool) err } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDropDifferentUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports. -type FilterInputDropTCPDestPort struct{} +type FilterInputDropTCPDestPort struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropTCPDestPort) Name() string { @@ -180,33 +185,36 @@ func (FilterInputDropTCPDestPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for TCP packets on drop port. - if err := listenTCP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, dropPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPDestPort) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Ensure we cannot connect to the container. - for start := time.Now(); time.Since(start) < sendloopDuration; { - if err := connectTCP(ip, dropPort, sendloopDuration-time.Since(start)); err == nil { - return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort) - } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort) } - return nil } // FilterInputDropTCPSrcPort tests that connections are not accepted on specified source ports. -type FilterInputDropTCPSrcPort struct{} +type FilterInputDropTCPSrcPort struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropTCPSrcPort) Name() string { @@ -214,34 +222,37 @@ func (FilterInputDropTCPSrcPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Drop anything from an ephemeral port. if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Ensure we cannot connect to the container. - for start := time.Now(); time.Since(start) < sendloopDuration; { - if err := connectTCP(ip, acceptPort, sendloopDuration-time.Since(start)); err == nil { - return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort) - } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort) } - return nil } // FilterInputDropAll tests that we can drop all traffic to the INPUT chain. -type FilterInputDropAll struct{} +type FilterInputDropAll struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropAll) Name() string { @@ -249,15 +260,17 @@ func (FilterInputDropAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropAll) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-j", "DROP"); err != nil { return err } // Listen for all packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets should have been dropped, but got a packet") - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -267,15 +280,15 @@ func (FilterInputDropAll) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropAll) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputMultiUDPRules verifies that multiple UDP rules are applied // correctly. This has the added benefit of testing whether we're serializing // rules correctly -- if we do it incorrectly, the iptables tool will // misunderstand and save the wrong tables. -type FilterInputMultiUDPRules struct{} +type FilterInputMultiUDPRules struct{ baseCase } // Name implements TestCase.Name. func (FilterInputMultiUDPRules) Name() string { @@ -283,7 +296,7 @@ func (FilterInputMultiUDPRules) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputMultiUDPRules) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputMultiUDPRules) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, @@ -293,14 +306,14 @@ func (FilterInputMultiUDPRules) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputMultiUDPRules) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputMultiUDPRules) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputRequireProtocolUDP checks that "-m udp" requires "-p udp" to be // specified. -type FilterInputRequireProtocolUDP struct{} +type FilterInputRequireProtocolUDP struct{ baseCase } // Name implements TestCase.Name. func (FilterInputRequireProtocolUDP) Name() string { @@ -308,20 +321,20 @@ func (FilterInputRequireProtocolUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputRequireProtocolUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil { return errors.New("expected iptables to fail with out \"-p udp\", but succeeded") } return nil } -func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputRequireProtocolUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputCreateUserChain tests chain creation. -type FilterInputCreateUserChain struct{} +type FilterInputCreateUserChain struct{ baseCase } // Name implements TestCase.Name. func (FilterInputCreateUserChain) Name() string { @@ -329,7 +342,7 @@ func (FilterInputCreateUserChain) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputCreateUserChain) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputCreateUserChain) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ // Create a chain. {"-N", chainName}, @@ -340,13 +353,13 @@ func (FilterInputCreateUserChain) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputCreateUserChain) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputCreateUserChain) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputDefaultPolicyAccept tests the default ACCEPT policy. -type FilterInputDefaultPolicyAccept struct{} +type FilterInputDefaultPolicyAccept struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDefaultPolicyAccept) Name() string { @@ -354,21 +367,21 @@ func (FilterInputDefaultPolicyAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDefaultPolicyAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Set the default policy to accept, then receive a packet. if err := filterTable(ipv6, "-P", "INPUT", "ACCEPT"); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDefaultPolicyAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDefaultPolicyDrop tests the default DROP policy. -type FilterInputDefaultPolicyDrop struct{} +type FilterInputDefaultPolicyDrop struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDefaultPolicyDrop) Name() string { @@ -376,15 +389,17 @@ func (FilterInputDefaultPolicyDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDefaultPolicyDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-P", "INPUT", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -394,13 +409,13 @@ func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP, ipv6 bool) error } // LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDefaultPolicyDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes // the underflow rule (i.e. default policy) to be executed. -type FilterInputReturnUnderflow struct{} +type FilterInputReturnUnderflow struct{ containerCase } // Name implements TestCase.Name. func (FilterInputReturnUnderflow) Name() string { @@ -408,7 +423,7 @@ func (FilterInputReturnUnderflow) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputReturnUnderflow) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputReturnUnderflow) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Add a RETURN rule followed by an unconditional accept, and set the // default policy to DROP. rules := [][]string{ @@ -422,16 +437,16 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP, ipv6 bool) error { // We should receive packets, as the RETURN rule will trigger the default // ACCEPT policy. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputReturnUnderflow) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputReturnUnderflow) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputSerializeJump verifies that we can serialize jumps. -type FilterInputSerializeJump struct{} +type FilterInputSerializeJump struct{ baseCase } // Name implements TestCase.Name. func (FilterInputSerializeJump) Name() string { @@ -439,7 +454,7 @@ func (FilterInputSerializeJump) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputSerializeJump) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputSerializeJump) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Write a JUMP rule, the serialize it with `-L`. rules := [][]string{ {"-N", chainName}, @@ -450,13 +465,13 @@ func (FilterInputSerializeJump) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputSerializeJump) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputSerializeJump) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputJumpBasic jumps to a chain and executes a rule there. -type FilterInputJumpBasic struct{} +type FilterInputJumpBasic struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpBasic) Name() string { @@ -464,7 +479,7 @@ func (FilterInputJumpBasic) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBasic) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpBasic) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-P", "INPUT", "DROP"}, {"-N", chainName}, @@ -476,16 +491,16 @@ func (FilterInputJumpBasic) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBasic) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpBasic) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputJumpReturn jumps, returns, and executes a rule. -type FilterInputJumpReturn struct{} +type FilterInputJumpReturn struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpReturn) Name() string { @@ -493,7 +508,7 @@ func (FilterInputJumpReturn) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturn) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpReturn) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-N", chainName}, {"-P", "INPUT", "ACCEPT"}, @@ -506,16 +521,16 @@ func (FilterInputJumpReturn) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturn) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpReturn) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets. -type FilterInputJumpReturnDrop struct{} +type FilterInputJumpReturnDrop struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpReturnDrop) Name() string { @@ -523,7 +538,7 @@ func (FilterInputJumpReturnDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpReturnDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-N", chainName}, {"-A", "INPUT", "-j", chainName}, @@ -535,9 +550,11 @@ func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -547,12 +564,12 @@ func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturnDrop) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputJumpReturnDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal. -type FilterInputJumpBuiltin struct{} +type FilterInputJumpBuiltin struct{ baseCase } // Name implements TestCase.Name. func (FilterInputJumpBuiltin) Name() string { @@ -560,7 +577,7 @@ func (FilterInputJumpBuiltin) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBuiltin) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpBuiltin) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-j", "OUTPUT"); err == nil { return fmt.Errorf("iptables should be unable to jump to a built-in chain") } @@ -568,13 +585,13 @@ func (FilterInputJumpBuiltin) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBuiltin) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpBuiltin) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputJumpTwice jumps twice, then returns twice and executes a rule. -type FilterInputJumpTwice struct{} +type FilterInputJumpTwice struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpTwice) Name() string { @@ -582,7 +599,7 @@ func (FilterInputJumpTwice) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpTwice) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpTwice) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { const chainName2 = chainName + "2" rules := [][]string{ {"-P", "INPUT", "DROP"}, @@ -598,17 +615,17 @@ func (FilterInputJumpTwice) ContainerAction(ip net.IP, ipv6 bool) error { // UDP packets should jump and return twice, eventually hitting the // ACCEPT rule. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpTwice) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpTwice) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDestination verifies that we can filter packets via `-d // <ipaddr>`. -type FilterInputDestination struct{} +type FilterInputDestination struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDestination) Name() string { @@ -616,7 +633,7 @@ func (FilterInputDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDestination) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { addrs, err := localAddrs(ipv6) if err != nil { return err @@ -632,17 +649,17 @@ func (FilterInputDestination) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputDestination) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputInvertDestination verifies that we can filter packets via `! -d // <ipaddr>`. -type FilterInputInvertDestination struct{} +type FilterInputInvertDestination struct{ containerCase } // Name implements TestCase.Name. func (FilterInputInvertDestination) Name() string { @@ -650,7 +667,7 @@ func (FilterInputInvertDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets not bound // for 127.0.0.1. rules := [][]string{ @@ -661,17 +678,17 @@ func (FilterInputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputInvertDestination) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputSource verifies that we can filter packets via `-s // <ipaddr>`. -type FilterInputSource struct{} +type FilterInputSource struct{ containerCase } // Name implements TestCase.Name. func (FilterInputSource) Name() string { @@ -679,7 +696,7 @@ func (FilterInputSource) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputSource) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets from this // machine. rules := [][]string{ @@ -690,17 +707,17 @@ func (FilterInputSource) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputSource) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputInvertSource verifies that we can filter packets via `! -s // <ipaddr>`. -type FilterInputInvertSource struct{} +type FilterInputInvertSource struct{ containerCase } // Name implements TestCase.Name. func (FilterInputInvertSource) Name() string { @@ -708,7 +725,7 @@ func (FilterInputInvertSource) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputInvertSource) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputInvertSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets not bound // for 127.0.0.1. rules := [][]string{ @@ -719,10 +736,10 @@ func (FilterInputInvertSource) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputInvertSource) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputInvertSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index c1d83b471..32bf2a992 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -15,6 +15,8 @@ package iptables import ( + "context" + "errors" "fmt" "net" ) @@ -44,7 +46,7 @@ func init() { // FilterOutputDropTCPDestPort tests that connections are not accepted on // specified source ports. -type FilterOutputDropTCPDestPort struct{} +type FilterOutputDropTCPDestPort struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPDestPort) Name() string { @@ -52,22 +54,28 @@ func (FilterOutputDropTCPDestPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) } @@ -76,7 +84,7 @@ func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputDropTCPSrcPort tests that connections are not accepted on // specified source ports. -type FilterOutputDropTCPSrcPort struct{} +type FilterOutputDropTCPSrcPort struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPSrcPort) Name() string { @@ -84,22 +92,28 @@ func (FilterOutputDropTCPSrcPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for TCP packets on drop port. - if err := listenTCP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, dropPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, dropPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) } @@ -107,7 +121,7 @@ func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP, ipv6 bool) error { } // FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted. -type FilterOutputAcceptTCPOwner struct{} +type FilterOutputAcceptTCPOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputAcceptTCPOwner) Name() string { @@ -115,22 +129,22 @@ func (FilterOutputAcceptTCPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputAcceptTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputAcceptTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped. -type FilterOutputDropTCPOwner struct{} +type FilterOutputDropTCPOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPOwner) Name() string { @@ -138,22 +152,28 @@ func (FilterOutputDropTCPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPOwner) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) } @@ -161,7 +181,7 @@ func (FilterOutputDropTCPOwner) LocalAction(ip net.IP, ipv6 bool) error { } // FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted. -type FilterOutputAcceptUDPOwner struct{} +type FilterOutputAcceptUDPOwner struct{ localCase } // Name implements TestCase.Name. func (FilterOutputAcceptUDPOwner) Name() string { @@ -169,23 +189,23 @@ func (FilterOutputAcceptUDPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputAcceptUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Send UDP packets on acceptPort. - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterOutputAcceptUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped. -type FilterOutputDropUDPOwner struct{} +type FilterOutputDropUDPOwner struct{ localCase } // Name implements TestCase.Name. func (FilterOutputDropUDPOwner) Name() string { @@ -193,20 +213,24 @@ func (FilterOutputDropUDPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { return err } // Send UDP packets on dropPort. - return sendUDPLoop(ip, dropPort, sendloopDuration) + return sendUDPLoop(ctx, ip, dropPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropUDPOwner) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets should not be received") + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -214,7 +238,7 @@ func (FilterOutputDropUDPOwner) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputOwnerFail tests that without uid/gid option, owner rule // will fail. -type FilterOutputOwnerFail struct{} +type FilterOutputOwnerFail struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputOwnerFail) Name() string { @@ -222,7 +246,7 @@ func (FilterOutputOwnerFail) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputOwnerFail) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputOwnerFail) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { return fmt.Errorf("Invalid argument") } @@ -231,13 +255,13 @@ func (FilterOutputOwnerFail) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputOwnerFail) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterOutputOwnerFail) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // no-op. return nil } // FilterOutputAcceptGIDOwner tests that TCP connections from gid owner are accepted. -type FilterOutputAcceptGIDOwner struct{} +type FilterOutputAcceptGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputAcceptGIDOwner) Name() string { @@ -245,22 +269,22 @@ func (FilterOutputAcceptGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputAcceptGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputAcceptGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputDropGIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputDropGIDOwner struct{} +type FilterOutputDropGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropGIDOwner) Name() string { @@ -268,22 +292,28 @@ func (FilterOutputDropGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -291,7 +321,7 @@ func (FilterOutputDropGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { } // FilterOutputInvertGIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputInvertGIDOwner struct{} +type FilterOutputInvertGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertGIDOwner) Name() string { @@ -299,7 +329,7 @@ func (FilterOutputInvertGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInvertGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--gid-owner", "root", "-j", "ACCEPT"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, @@ -309,16 +339,22 @@ func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInvertGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -326,7 +362,7 @@ func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { } // FilterOutputInvertUIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputInvertUIDOwner struct{} +type FilterOutputInvertUIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertUIDOwner) Name() string { @@ -334,7 +370,7 @@ func (FilterOutputInvertUIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInvertUIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "-j", "DROP"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "ACCEPT"}, @@ -344,17 +380,17 @@ func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertUIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputInvertUIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputInvertUIDAndGIDOwner tests that TCP connections from uid and gid // owner are dropped. -type FilterOutputInvertUIDAndGIDOwner struct{} +type FilterOutputInvertUIDAndGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertUIDAndGIDOwner) Name() string { @@ -362,7 +398,7 @@ func (FilterOutputInvertUIDAndGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "!", "--gid-owner", "root", "-j", "ACCEPT"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, @@ -372,16 +408,22 @@ func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP, ipv6 bool) er } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -390,7 +432,7 @@ func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP, ipv6 bool) error // FilterOutputDestination tests that we can selectively allow packets to // certain destinations. -type FilterOutputDestination struct{} +type FilterOutputDestination struct{ localCase } // Name implements TestCase.Name. func (FilterOutputDestination) Name() string { @@ -398,7 +440,7 @@ func (FilterOutputDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDestination) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, {"-P", "OUTPUT", "DROP"}, @@ -407,17 +449,17 @@ func (FilterOutputDestination) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDestination) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInvertDestination tests that we can selectively allow packets // not headed for a particular destination. -type FilterOutputInvertDestination struct{} +type FilterOutputInvertDestination struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInvertDestination) Name() string { @@ -425,7 +467,7 @@ func (FilterOutputInvertDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"}, {"-P", "OUTPUT", "DROP"}, @@ -434,17 +476,17 @@ func (FilterOutputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertDestination) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceAccept tests that packets are sent via interface // matching the iptables rule. -type FilterOutputInterfaceAccept struct{} +type FilterOutputInterfaceAccept struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceAccept) Name() string { @@ -452,7 +494,7 @@ func (FilterOutputInterfaceAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { ifname, ok := getInterfaceName() if !ok { return fmt.Errorf("no interface is present, except loopback") @@ -461,17 +503,17 @@ func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceAccept) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInterfaceAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceDrop tests that packets are not sent via interface // matching the iptables rule. -type FilterOutputInterfaceDrop struct{} +type FilterOutputInterfaceDrop struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceDrop) Name() string { @@ -479,7 +521,7 @@ func (FilterOutputInterfaceDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { ifname, ok := getInterfaceName() if !ok { return fmt.Errorf("no interface is present, except loopback") @@ -488,13 +530,17 @@ func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceDrop) LocalAction(ip net.IP, ipv6 bool) error { - if err := listenUDP(acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -502,7 +548,7 @@ func (FilterOutputInterfaceDrop) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputInterface tests that packets are sent via interface which is // not matching the interface name in the iptables rule. -type FilterOutputInterface struct{} +type FilterOutputInterface struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterface) Name() string { @@ -510,22 +556,22 @@ func (FilterOutputInterface) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterface) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterface) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterface) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInterface) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceBeginsWith tests that packets are not sent via an // interface which begins with the given interface name. -type FilterOutputInterfaceBeginsWith struct{} +type FilterOutputInterfaceBeginsWith struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceBeginsWith) Name() string { @@ -533,18 +579,22 @@ func (FilterOutputInterfaceBeginsWith) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceBeginsWith) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP, ipv6 bool) error { - if err := listenUDP(acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceBeginsWith) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -552,7 +602,7 @@ func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputInterfaceInvertDrop tests that we selectively do not send // packets via interface not matching the interface name. -type FilterOutputInterfaceInvertDrop struct{} +type FilterOutputInterfaceInvertDrop struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInterfaceInvertDrop) Name() string { @@ -560,22 +610,28 @@ func (FilterOutputInterfaceInvertDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceInvertDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceInvertDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -584,7 +640,7 @@ func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputInterfaceInvertAccept tests that we can selectively send packets // not matching the specific outgoing interface. -type FilterOutputInterfaceInvertAccept struct{} +type FilterOutputInterfaceInvertAccept struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInterfaceInvertAccept) Name() string { @@ -592,16 +648,16 @@ func (FilterOutputInterfaceInvertAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceInvertAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputInterfaceInvertAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go index dfbd80cd1..c2a03f54c 100644 --- a/test/iptables/iptables.go +++ b/test/iptables/iptables.go @@ -16,6 +16,7 @@ package iptables import ( + "context" "fmt" "net" "time" @@ -29,7 +30,11 @@ const IPExchangePort = 2349 const TerminalStatement = "Finished!" // TestTimeout is the timeout used for all tests. -const TestTimeout = 10 * time.Minute +const TestTimeout = 10 * time.Second + +// NegativeTimeout is the time tests should wait to establish the negative +// case, i.e. that connections are not made. +const NegativeTimeout = 2 * time.Second // A TestCase contains one action to run in the container and one to run // locally. The actions run concurrently and each must succeed for the test @@ -40,10 +45,60 @@ type TestCase interface { // ContainerAction runs inside the container. It receives the IP of the // local process. - ContainerAction(ip net.IP, ipv6 bool) error + ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error // LocalAction runs locally. It receives the IP of the container. - LocalAction(ip net.IP, ipv6 bool) error + LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error + + // ContainerSufficient indicates whether ContainerAction's return value + // alone indicates whether the test succeeded. + ContainerSufficient() bool + + // LocalSufficient indicates whether LocalAction's return value alone + // indicates whether the test succeeded. + LocalSufficient() bool +} + +// baseCase provides defaults for ContainerSufficient and LocalSufficient when +// both actions are required to finish. +type baseCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (baseCase) ContainerSufficient() bool { + return false +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (baseCase) LocalSufficient() bool { + return false +} + +// localCase provides defaults for ContainerSufficient and LocalSufficient when +// only the local action is required to finish. +type localCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (localCase) ContainerSufficient() bool { + return false +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (localCase) LocalSufficient() bool { + return true +} + +// containerCase provides defaults for ContainerSufficient and LocalSufficient +// when only the container action is required to finish. +type containerCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (containerCase) ContainerSufficient() bool { + return true +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (containerCase) LocalSufficient() bool { + return false } // Tests maps test names to TestCase. diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index fda5f694f..e2beb30d5 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -16,9 +16,11 @@ package iptables import ( "context" + "errors" "fmt" "net" "reflect" + "sync" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -51,9 +53,24 @@ func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) } - ctx := context.Background() + // Wait for the local and container goroutines to finish. + var wg sync.WaitGroup + defer wg.Wait() + + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) + defer func() { + if logs, err := d.Logs(context.Background()); err != nil { + t.Logf("Failed to retrieve container logs.") + } else { + t.Logf("=== Container logs: ===\n%s", logs) + } + // Use a new context, as cleanup should run even when we + // timeout. + d.CleanUp(context.Background()) + }() // TODO(gvisor.dev/issue/170): Skipping IPv6 gVisor tests. if ipv6 && dockerutil.Runtime() != "runc" { @@ -86,15 +103,44 @@ func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { } // Run our side of the test. - if err := test.LocalAction(ip, ipv6); err != nil { - t.Fatalf("LocalAction failed: %v", err) - } - - // Wait for the final statement. This structure has the side effect - // that all container logs will appear within the individual test - // context. - if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil { - t.Fatalf("test failed: %v", err) + errCh := make(chan error, 2) + wg.Add(1) + go func() { + defer wg.Done() + if err := test.LocalAction(ctx, ip, ipv6); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("LocalAction failed: %v", err) + } else { + errCh <- nil + } + if test.LocalSufficient() { + errCh <- nil + } + }() + + // Run the container side. + wg.Add(1) + go func() { + defer wg.Done() + // Wait for the final statement. This structure has the side + // effect that all container logs will appear within the + // individual test context. + if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("ContainerAction failed: %v", err) + } else { + errCh <- nil + } + if test.ContainerSufficient() { + errCh <- nil + } + }() + + for i := 0; i < 2; i++ { + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } } } diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 5125fe47b..a6ec5cca3 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -15,6 +15,7 @@ package iptables import ( + "context" "encoding/binary" "errors" "fmt" @@ -70,7 +71,7 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error { // listenUDP listens on a UDP port and returns the value of net.Conn.Read() for // the first read on that port. -func listenUDP(port int, timeout time.Duration) error { +func listenUDP(ctx context.Context, port int) error { localAddr := net.UDPAddr{ Port: port, } @@ -79,68 +80,53 @@ func listenUDP(port int, timeout time.Duration) error { return err } defer conn.Close() - conn.SetDeadline(time.Now().Add(timeout)) - _, err = conn.Read([]byte{0}) - return err -} -// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified -// over a duration. -func sendUDPLoop(ip net.IP, port int, duration time.Duration) error { - conn, err := connectUDP(ip, port) - if err != nil { - return err - } - defer conn.Close() - loopUDP(conn, duration) - return nil -} + ch := make(chan error) + go func() { + _, err = conn.Read([]byte{0}) + ch <- err + }() -// spawnUDPLoop works like sendUDPLoop, but returns immediately and sends -// packets in another goroutine. -func spawnUDPLoop(ip net.IP, port int, duration time.Duration) error { - conn, err := connectUDP(ip, port) - if err != nil { + select { + case err := <-ch: return err + case <-ctx.Done(): + return ctx.Err() } - go func() { - defer conn.Close() - loopUDP(conn, duration) - }() - return nil } -func connectUDP(ip net.IP, port int) (net.Conn, error) { +// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified +// over a duration. +func sendUDPLoop(ctx context.Context, ip net.IP, port int) error { remote := net.UDPAddr{ IP: ip, Port: port, } conn, err := net.DialUDP("udp", nil, &remote) if err != nil { - return nil, err + return err } - return conn, nil -} + defer conn.Close() -func loopUDP(conn net.Conn, duration time.Duration) { - to := time.After(duration) - for timedOut := false; !timedOut; { + for { // This may return an error (connection refused) if the remote // hasn't started listening yet or they're dropping our // packets. So we ignore Write errors and depend on the remote // to report a failure if it doesn't get a packet it needs. conn.Write([]byte{0}) select { - case <-to: - timedOut = true - default: - time.Sleep(200 * time.Millisecond) + case <-ctx.Done(): + // Being cancelled or timing out isn't an error, as we + // cannot tell with UDP whether we succeeded. + return nil + // Continue looping. + case <-time.After(200 * time.Millisecond): } } } // listenTCP listens for connections on a TCP port. -func listenTCP(port int, timeout time.Duration) error { +func listenTCP(ctx context.Context, port int) error { localAddr := net.TCPAddr{ Port: port, } @@ -153,17 +139,23 @@ func listenTCP(port int, timeout time.Duration) error { defer lConn.Close() // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - conn, err := lConn.AcceptTCP() - if err != nil { + ch := make(chan error) + go func() { + conn, err := lConn.AcceptTCP() + ch <- err + conn.Close() + }() + + select { + case err := <-ch: return err + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err()) } - conn.Close() - return nil } // connectTCP connects to the given IP and port from an ephemeral local address. -func connectTCP(ip net.IP, port int, timeout time.Duration) error { +func connectTCP(ctx context.Context, ip net.IP, port int) error { contAddr := net.TCPAddr{ IP: ip, Port: port, @@ -171,13 +163,14 @@ func connectTCP(ip net.IP, port int, timeout time.Duration) error { // The container may not be listening when we first connect, so retry // upon error. callback := func() error { - conn, err := net.DialTimeout("tcp", contAddr.String(), timeout) + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", contAddr.String()) if conn != nil { conn.Close() } return err } - if err := testutil.Poll(callback, timeout); err != nil { + if err := testutil.PollContext(ctx, callback); err != nil { return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index b7fea2527..dd9a18339 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -15,11 +15,11 @@ package iptables import ( + "context" "errors" "fmt" "net" "syscall" - "time" ) const redirectPort = 42 @@ -46,7 +46,7 @@ func init() { } // NATPreRedirectUDPPort tests that packets are redirected to different port. -type NATPreRedirectUDPPort struct{} +type NATPreRedirectUDPPort struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectUDPPort) Name() string { @@ -54,12 +54,12 @@ func (NATPreRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } - if err := listenUDP(redirectPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, redirectPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err) } @@ -67,12 +67,12 @@ func (NATPreRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectUDPPort) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATPreRedirectTCPPort tests that connections are redirected on specified ports. -type NATPreRedirectTCPPort struct{} +type NATPreRedirectTCPPort struct{ baseCase } // Name implements TestCase.Name. func (NATPreRedirectTCPPort) Name() string { @@ -80,23 +80,23 @@ func (NATPreRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } // Listen for TCP packets on redirect port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectTCPPort) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, dropPort, sendloopDuration) +func (NATPreRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) } // NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't // affected by PREROUTING connection tracking. -type NATPreRedirectTCPOutgoing struct{} +type NATPreRedirectTCPOutgoing struct{ baseCase } // Name implements TestCase.Name. func (NATPreRedirectTCPOutgoing) Name() string { @@ -104,24 +104,24 @@ func (NATPreRedirectTCPOutgoing) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectTCPOutgoing) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectTCPOutgoing) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect all incoming TCP traffic to a closed port. if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } // Establish a connection to the host process. - return connectTCP(ip, acceptPort, sendloopDuration) + return connectTCP(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectTCPOutgoing) LocalAction(ip net.IP, ipv6 bool) error { - return listenTCP(acceptPort, sendloopDuration) +func (NATPreRedirectTCPOutgoing) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenTCP(ctx, acceptPort) } // NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't // affected by OUTPUT connection tracking. -type NATOutRedirectTCPIncoming struct{} +type NATOutRedirectTCPIncoming struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectTCPIncoming) Name() string { @@ -129,23 +129,23 @@ func (NATOutRedirectTCPIncoming) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectTCPIncoming) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPIncoming) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect all outgoing TCP traffic to a closed port. if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } // Establish a connection to the host process. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectTCPIncoming) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (NATOutRedirectTCPIncoming) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // NATOutRedirectUDPPort tests that packets are redirected to different port. -type NATOutRedirectUDPPort struct{} +type NATOutRedirectUDPPort struct{ containerCase } // Name implements TestCase.Name. func (NATOutRedirectUDPPort) Name() string { @@ -153,19 +153,19 @@ func (NATOutRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) +func (NATOutRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectUDPPort) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATDropUDP tests that packets are not received in ports other than redirect // port. -type NATDropUDP struct{} +type NATDropUDP struct{ containerCase } // Name implements TestCase.Name. func (NATDropUDP) Name() string { @@ -173,25 +173,29 @@ func (NATDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } - if err := listenUDP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (NATDropUDP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATAcceptAll tests that all UDP packets are accepted. -type NATAcceptAll struct{} +type NATAcceptAll struct{ containerCase } // Name implements TestCase.Name. func (NATAcceptAll) Name() string { @@ -199,12 +203,12 @@ func (NATAcceptAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATAcceptAll) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATAcceptAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { return err } - if err := listenUDP(acceptPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, acceptPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) } @@ -212,13 +216,13 @@ func (NATAcceptAll) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATAcceptAll) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATAcceptAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATOutRedirectIP uses iptables to select packets based on destination IP and // redirects them. -type NATOutRedirectIP struct{} +type NATOutRedirectIP struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectIP) Name() string { @@ -226,9 +230,9 @@ func (NATOutRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-d", nowhereIP(ipv6), "-p", "udp", @@ -236,14 +240,14 @@ func (NATOutRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATOutDontRedirectIP tests that iptables matching with "-d" does not match // packets it shouldn't. -type NATOutDontRedirectIP struct{} +type NATOutDontRedirectIP struct{ localCase } // Name implements TestCase.Name. func (NATOutDontRedirectIP) Name() string { @@ -251,20 +255,20 @@ func (NATOutDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutDontRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "OUTPUT", "-d", localIP(ipv6), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATOutDontRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (NATOutDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // NATOutRedirectInvert tests that iptables can match with "! -d". -type NATOutRedirectInvert struct{} +type NATOutRedirectInvert struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectInvert) Name() string { @@ -272,13 +276,13 @@ func (NATOutRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. dest := "192.0.2.2" if ipv6 { dest = "2001:db8::2" } - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "!", "-d", dest, "-p", "udp", @@ -286,14 +290,14 @@ func (NATOutRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectInvert) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATPreRedirectIP tests that we can use iptables to select packets based on // destination IP and redirect them. -type NATPreRedirectIP struct{} +type NATPreRedirectIP struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectIP) Name() string { @@ -301,7 +305,7 @@ func (NATPreRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { addrs, err := localAddrs(ipv6) if err != nil { return err @@ -314,17 +318,17 @@ func (NATPreRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { if err := natTableRules(ipv6, rules); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // NATPreDontRedirectIP tests that iptables matching with "-d" does not match // packets it shouldn't. -type NATPreDontRedirectIP struct{} +type NATPreDontRedirectIP struct{ containerCase } // Name implements TestCase.Name. func (NATPreDontRedirectIP) Name() string { @@ -332,20 +336,20 @@ func (NATPreDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreDontRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreDontRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATPreRedirectInvert tests that iptables can match with "! -d". -type NATPreRedirectInvert struct{} +type NATPreRedirectInvert struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectInvert) Name() string { @@ -353,21 +357,21 @@ func (NATPreRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "!", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectInvert) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a // protocol to be specified with -p. -type NATRedirectRequiresProtocol struct{} +type NATRedirectRequiresProtocol struct{ baseCase } // Name implements TestCase.Name. func (NATRedirectRequiresProtocol) Name() string { @@ -375,7 +379,7 @@ func (NATRedirectRequiresProtocol) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATRedirectRequiresProtocol) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { return errors.New("expected an error using REDIRECT --to-ports without a protocol") } @@ -383,13 +387,13 @@ func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATRedirectRequiresProtocol) LocalAction(ip net.IP, ipv6 bool) error { +func (NATRedirectRequiresProtocol) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATOutRedirectTCPPort tests that connections are redirected on specified ports. -type NATOutRedirectTCPPort struct{} +type NATOutRedirectTCPPort struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectTCPPort) Name() string { @@ -397,12 +401,11 @@ func (NATOutRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - timeout := 20 * time.Second localAddr := net.TCPAddr{ IP: net.ParseIP(localIP(ipv6)), Port: acceptPort, @@ -416,9 +419,7 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { defer lConn.Close() // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - err = connectTCP(ip, dropPort, timeout) - if err != nil { + if err := connectTCP(ctx, ip, dropPort); err != nil { return err } @@ -432,13 +433,13 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectTCPPort) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { return nil } // NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't // affected by PREROUTING rules. -type NATLoopbackSkipsPrerouting struct{} +type NATLoopbackSkipsPrerouting struct{ baseCase } // Name implements TestCase.Name. func (NATLoopbackSkipsPrerouting) Name() string { @@ -446,7 +447,7 @@ func (NATLoopbackSkipsPrerouting) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect anything sent to localhost to an unused port. dest := []byte{127, 0, 0, 1} if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { @@ -457,24 +458,24 @@ func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP, ipv6 bool) error { // loopback traffic, the connection would fail. sendCh := make(chan error) go func() { - sendCh <- connectTCP(dest, acceptPort, sendloopDuration) + sendCh <- connectTCP(ctx, dest, acceptPort) }() - if err := listenTCP(acceptPort, sendloopDuration); err != nil { + if err := listenTCP(ctx, acceptPort); err != nil { return err } return <-sendCh } // LocalAction implements TestCase.LocalAction. -func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP, ipv6 bool) error { +func (NATLoopbackSkipsPrerouting) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination // of PREROUTING NATted packets. -type NATPreOriginalDst struct{} +type NATPreOriginalDst struct{ baseCase } // Name implements TestCase.Name. func (NATPreOriginalDst) Name() string { @@ -482,7 +483,7 @@ func (NATPreOriginalDst) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect incoming TCP connections to acceptPort. if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", @@ -495,17 +496,17 @@ func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { if err != nil { return err } - return listenForRedirectedConn(ipv6, addrs) + return listenForRedirectedConn(ctx, ipv6, addrs) } // LocalAction implements TestCase.LocalAction. -func (NATPreOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, dropPort, sendloopDuration) +func (NATPreOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) } // NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination // of OUTBOUND NATted packets. -type NATOutOriginalDst struct{} +type NATOutOriginalDst struct{ baseCase } // Name implements TestCase.Name. func (NATOutOriginalDst) Name() string { @@ -513,7 +514,7 @@ func (NATOutOriginalDst) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect incoming TCP connections to acceptPort. if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { return err @@ -521,22 +522,22 @@ func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { connCh := make(chan error) go func() { - connCh <- connectTCP(ip, dropPort, sendloopDuration) + connCh <- connectTCP(ctx, ip, dropPort) }() - if err := listenForRedirectedConn(ipv6, []net.IP{ip}); err != nil { + if err := listenForRedirectedConn(ctx, ipv6, []net.IP{ip}); err != nil { return err } return <-connCh } // LocalAction implements TestCase.LocalAction. -func (NATOutOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } -func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { +func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error { // The net package doesn't give guarantee access to the connection's // underlying FD, and thus we cannot call getsockopt. We have to use // traditional syscalls for SO_ORIGINAL_DST. @@ -572,16 +573,32 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { return err } - connfd, _, err := syscall.Accept(sockfd) - if err != nil { + // Block on accept() in another goroutine. + connCh := make(chan int) + errCh := make(chan error) + go func() { + connFD, _, err := syscall.Accept(sockfd) + if err != nil { + errCh <- err + } + connCh <- connFD + }() + + // Wait for accept() to return or for the context to finish. + var connFD int + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: return err + case connFD = <-connCh: } - defer syscall.Close(connfd) + defer syscall.Close(connFD) // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST // indicates the packet was sent to originalDst:dropPort. if ipv6 { - got, err := originalDestination6(connfd) + got, err := originalDestination6(connFD) if err != nil { return err } @@ -598,7 +615,7 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { } return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) } else { - got, err := originalDestination4(connfd) + got, err := originalDestination4(connFD) if err != nil { return err } @@ -619,26 +636,22 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { // loopbackTests runs an iptables rule and ensures that packets sent to // dest:dropPort are received by localhost:acceptPort. -func loopbackTest(ipv6 bool, dest net.IP, args ...string) error { +func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) error { if err := natTable(ipv6, args...); err != nil { return err } - sendCh := make(chan error) - listenCh := make(chan error) + sendCh := make(chan error, 1) + listenCh := make(chan error, 1) go func() { - sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration) + sendCh <- sendUDPLoop(ctx, dest, dropPort) }() go func() { - listenCh <- listenUDP(acceptPort, sendloopDuration) + listenCh <- listenUDP(ctx, acceptPort) }() select { case err := <-listenCh: - if err != nil { - return err - } - case <-time.After(sendloopDuration): - return errors.New("timed out") + return err + case err := <-sendCh: + return err } - // sendCh will always take the full sendloop time. - return <-sendCh } diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go index 69d3ef121..9ae2d1b4d 100644 --- a/test/iptables/runner/main.go +++ b/test/iptables/runner/main.go @@ -16,6 +16,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -46,7 +47,9 @@ func main() { } // Run the test. - if err := test.ContainerAction(ip, *ipv6); err != nil { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := test.ContainerAction(ctx, ip, *ipv6); err != nil { log.Fatalf("Failed running test %q: %v", *name, err) } |