summaryrefslogtreecommitdiffhomepage
path: root/pkg/dhcp/dhcp_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/dhcp/dhcp_test.go')
-rw-r--r--pkg/dhcp/dhcp_test.go246
1 files changed, 223 insertions, 23 deletions
diff --git a/pkg/dhcp/dhcp_test.go b/pkg/dhcp/dhcp_test.go
index 731ed61a5..67814683a 100644
--- a/pkg/dhcp/dhcp_test.go
+++ b/pkg/dhcp/dhcp_test.go
@@ -27,9 +27,13 @@ import (
"gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
)
-func TestDHCP(t *testing.T) {
+const nicid = tcpip.NICID(1)
+const serverAddr = tcpip.Address("\xc0\xa8\x03\x01")
+
+func createStack(t *testing.T) *stack.Stack {
const defaultMTU = 65536
id, linkEP := channel.New(256, defaultMTU, "")
if testing.Verbose() {
@@ -48,17 +52,9 @@ func TestDHCP(t *testing.T) {
s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
- const nicid tcpip.NICID = 1
if err := s.CreateNIC(nicid, id); err != nil {
t.Fatal(err)
}
- if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil {
- t.Fatal(err)
- }
- if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil {
- t.Fatal(err)
- }
- const serverAddr = tcpip.Address("\xc0\xa8\x03\x01")
if err := s.AddAddress(nicid, ipv4.ProtocolNumber, serverAddr); err != nil {
t.Fatal(err)
}
@@ -70,31 +66,38 @@ func TestDHCP(t *testing.T) {
NIC: nicid,
}})
- var clientAddrs = []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"}
+ return s
+}
+
+func TestDHCP(t *testing.T) {
+ s := createStack(t)
+ clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"}
serverCfg := Config{
- ServerAddress: serverAddr,
- SubnetMask: "\xff\xff\xff\x00",
- Gateway: "\xc0\xa8\x03\xF0",
- DomainNameServer: "\x08\x08\x08\x08",
- LeaseLength: 24 * time.Hour,
+ ServerAddress: serverAddr,
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{
+ "\x08\x08\x08\x08", "\x08\x08\x04\x04",
+ },
+ LeaseLength: 24 * time.Hour,
}
serverCtx, cancel := context.WithCancel(context.Background())
defer cancel()
- _, err := NewServer(serverCtx, s, clientAddrs, serverCfg)
+ _, err := newEPConnServer(serverCtx, s, clientAddrs, serverCfg)
if err != nil {
t.Fatal(err)
}
const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
- c0 := NewClient(s, nicid, clientLinkAddr0)
- if err := c0.Request(context.Background(), ""); err != nil {
+ c0 := NewClient(s, nicid, clientLinkAddr0, nil)
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if err := c0.Request(context.Background(), ""); err != nil {
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
@@ -102,22 +105,219 @@ func TestDHCP(t *testing.T) {
}
const clientLinkAddr1 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x53")
- c1 := NewClient(s, nicid, clientLinkAddr1)
- if err := c1.Request(context.Background(), ""); err != nil {
+ c1 := NewClient(s, nicid, clientLinkAddr1, nil)
+ if _, err := c1.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c1.Address(), clientAddrs[1]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if err := c0.Request(context.Background(), ""); err != nil {
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if got, want := c0.Config(), serverCfg; got != want {
+ if got, want := c0.Config(), serverCfg; !equalConfig(got, want) {
t.Errorf("client config:\n\t%#+v\nwant:\n\t%#+v", got, want)
}
}
+
+func equalConfig(c0, c1 Config) bool {
+ if c0.Error != c1.Error || c0.ServerAddress != c1.ServerAddress || c0.SubnetMask != c1.SubnetMask || c0.Gateway != c1.Gateway || c0.LeaseLength != c1.LeaseLength {
+ return false
+ }
+ if len(c0.DNS) != len(c1.DNS) {
+ return false
+ }
+ for i := 0; i < len(c0.DNS); i++ {
+ if c0.DNS[i] != c1.DNS[i] {
+ return false
+ }
+ }
+ return true
+}
+
+func TestRenew(t *testing.T) {
+ s := createStack(t)
+ clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02"}
+
+ serverCfg := Config{
+ ServerAddress: serverAddr,
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 1 * time.Second,
+ }
+ serverCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, err := newEPConnServer(serverCtx, s, clientAddrs, serverCfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ count := 0
+ var curAddr tcpip.Address
+ addrCh := make(chan tcpip.Address)
+ acquiredFunc := func(oldAddr, newAddr tcpip.Address, cfg Config) {
+ if err := cfg.Error; err != nil {
+ t.Fatalf("acquisition %d failed: %v", count, err)
+ }
+ if oldAddr != curAddr {
+ t.Fatalf("aquisition %d: curAddr=%v, oldAddr=%v", count, curAddr, oldAddr)
+ }
+ if cfg.LeaseLength != time.Second {
+ t.Fatalf("aquisition %d: lease length: %v, want %v", count, cfg.LeaseLength, time.Second)
+ }
+ count++
+ curAddr = newAddr
+ addrCh <- newAddr
+ }
+
+ clientCtx, cancel := context.WithCancel(context.Background())
+ const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
+ c := NewClient(s, nicid, clientLinkAddr0, acquiredFunc)
+ c.Run(clientCtx)
+
+ var addr tcpip.Address
+ select {
+ case addr = <-addrCh:
+ t.Logf("got first address: %v", addr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout acquiring initial address")
+ }
+
+ select {
+ case newAddr := <-addrCh:
+ t.Logf("got renewal: %v", newAddr)
+ if newAddr != addr {
+ t.Fatalf("renewal address is %v, want %v", newAddr, addr)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for address renewal")
+ }
+
+ cancel()
+}
+
+// Regression test for https://fuchsia.atlassian.net/browse/NET-17
+func TestNoNullTerminator(t *testing.T) {
+ v := "\x02\x01\x06\x00" +
+ "\xc8\x37\xbe\x73\x00\x00\x80\x00\x00\x00\x00\x00\xc0\xa8\x2b\x92" +
+ "\xc0\xa8\x2b\x01\x00\x00\x00\x00\x00\x0f\x60\x0a\x23\x93\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x63\x82\x53\x63\x35\x01\x02\x36" +
+ "\x04\xc0\xa8\x2b\x01\x33\x04\x00\x00\x0e\x10\x3a\x04\x00\x00\x07" +
+ "\x08\x3b\x04\x00\x00\x0c\x4e\x01\x04\xff\xff\xff\x00\x1c\x04\xc0" +
+ "\xa8\x2b\xff\x03\x04\xc0\xa8\x2b\x01\x06\x04\xc0\xa8\x2b\x01\x2b" +
+ "\x0f\x41\x4e\x44\x52\x4f\x49\x44\x5f\x4d\x45\x54\x45\x52\x45\x44" +
+ "\xff"
+ h := header(v)
+ if !h.isValid() {
+ t.Error("failed to decode header")
+ }
+
+ if got, want := h.op(), opReply; got != want {
+ t.Errorf("h.op()=%v, want=%v", got, want)
+ }
+
+ if _, err := h.options(); err != nil {
+ t.Errorf("bad options: %v", err)
+ }
+}
+
+func teeConn(c conn) (conn, conn) {
+ dup1 := &dupConn{
+ c: c,
+ dup: make(chan connMsg, 8),
+ }
+ dup2 := &chConn{
+ c: c,
+ ch: dup1.dup,
+ }
+ return dup1, dup2
+}
+
+type connMsg struct {
+ buf buffer.View
+ addr tcpip.FullAddress
+ err error
+}
+
+type dupConn struct {
+ c conn
+ dup chan connMsg
+}
+
+func (c *dupConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ v, addr, err := c.c.Read()
+ c.dup <- connMsg{v, addr, err}
+ return v, addr, err
+}
+func (c *dupConn) Write(b []byte, addr *tcpip.FullAddress) error { return c.c.Write(b, addr) }
+
+type chConn struct {
+ ch chan connMsg
+ c conn
+}
+
+func (c *chConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ msg := <-c.ch
+ return msg.buf, msg.addr, msg.err
+}
+func (c *chConn) Write(b []byte, addr *tcpip.FullAddress) error { return c.c.Write(b, addr) }
+
+func TestTwoServers(t *testing.T) {
+ s := createStack(t)
+
+ wq := new(waiter.Queue)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("dhcp: server endpoint: %v", err)
+ }
+ if err = ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil {
+ t.Fatalf("dhcp: server bind: %v", err)
+ }
+
+ serverCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ c1, c2 := teeConn(newEPConn(serverCtx, wq, ep))
+
+ if _, err := NewServer(serverCtx, c1, []tcpip.Address{"\xc0\xa8\x03\x02"}, Config{
+ ServerAddress: "\xc0\xa8\x03\x01",
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 30 * time.Minute,
+ }); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := NewServer(serverCtx, c2, []tcpip.Address{"\xc0\xa8\x04\x02"}, Config{
+ ServerAddress: "\xc0\xa8\x04\x01",
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 30 * time.Minute,
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
+ c := NewClient(s, nicid, clientLinkAddr0, nil)
+ if _, err := c.Request(context.Background(), ""); err != nil {
+ t.Fatal(err)
+ }
+}