diff options
Diffstat (limited to 'pkg/tcpip/stack/stack_test.go')
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 68 |
1 files changed, 17 insertions, 51 deletions
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1deeccb89..60b54c244 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -158,23 +158,13 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack func (*fakeNetworkEndpoint) Close() {} -type fakeNetGoodOption bool - -type fakeNetBadOption bool - -type fakeNetInvalidValueOption int - -type fakeNetOptions struct { - good bool -} - // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int - opts fakeNetOptions + defaultTTL uint8 } func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -206,22 +196,20 @@ func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddress } } -func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeNetGoodOption: - f.opts.good = bool(v) + case *tcpip.DefaultTTLOption: + f.defaultTTL = uint8(*v) return nil - case fakeNetInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeNetGoodOption: - *v = fakeNetGoodOption(f.opts.good) + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(f.defaultTTL) return nil default: return tcpip.ErrUnknownProtocolOption @@ -1640,46 +1628,24 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -func TestNetworkOptions(t *testing.T) { +func TestNetworkOption(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, TransportProtocols: []stack.TransportProtocol{}, }) - // Try an unsupported network protocol. - if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) + opt := tcpip.DefaultTTLOption(5) + if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil { + t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err) } - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.NetworkProtocol) - }{ - {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { - t.Helper() - fakeNet := p.(*fakeNetworkProtocol) - if fakeNet.opts.good != true { - t.Fatalf("fakeNet.opts.good = false, want = true") - } - var v fakeNetGoodOption - if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) - } - }}, - {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, + var optGot tcpip.DefaultTTLOption + if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil { + t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err) } - for _, tc := range testCases { - if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) - } + + if opt != optGot { + t.Errorf("got optGot = %d, want = %d", optGot, opt) } } |