summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4/server_test.go
blob: d4557860a182c867d143fd5a9c2c53b963756a43 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
// +build integration

package dhcpv4

import (
	"errors"
	"log"
	"math/rand"
	"net"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// init seeds the global math/rand source from the current wall-clock time so
// that randPort picks different ports on each run. Time-based seeding of the
// global source is a poor idea in general, but it is good enough for choosing
// random test ports.
func init() {
	seed := time.Now().UTC().UnixNano()
	rand.Seed(seed)
}

// randPort returns a pseudo-random unprivileged port in the range
// [1024, 65535].
//
// Port 0 (kernel-assigned) can't be used with raw sockets. Until a
// non-raw-sockets client exists for non-static ports, tests have to live with
// this "randomness".
func randPort() int {
	const (
		minPort   = 1024
		portCount = 65536 - minPort
	)
	return minPort + rand.Intn(portCount)
}

// DORAHandler is a server handler suitable for DORA (Discover, Offer,
// Request, Ack) transactions. It answers a Discover with an Offer and a
// Request with an Ack, always setting the server identifier to 1.2.3.4, and
// writes the reply back to peer over conn.
//
// Packets that are nil, are not BootRequests, have no message type option, or
// carry any other message type are logged and dropped without a reply.
func DORAHandler(conn net.PacketConn, peer net.Addr, m *DHCPv4) {
	if m == nil {
		log.Printf("Packet is nil!")
		return
	}
	if m.Opcode() != OpcodeBootRequest {
		log.Printf("Not a BootRequest!")
		return
	}
	reply, err := NewReplyFromRequest(m)
	if err != nil {
		log.Printf("NewReplyFromRequest failed: %v", err)
		return
	}
	reply.AddOption(&OptServerIdentifier{ServerID: net.IP{1, 2, 3, 4}})
	opt := m.GetOneOption(OptionDHCPMessageType)
	if opt == nil {
		log.Printf("No message type found!")
		return
	}
	// Checked assertion: an option with an unexpected concrete type is
	// logged and dropped instead of panicking the handler.
	msgType, ok := opt.(*OptMessageType)
	if !ok {
		log.Printf("Unexpected message type option: %T", opt)
		return
	}
	switch msgType.MessageType {
	case MessageTypeDiscover:
		reply.AddOption(&OptMessageType{MessageType: MessageTypeOffer})
	case MessageTypeRequest:
		reply.AddOption(&OptMessageType{MessageType: MessageTypeAck})
	default:
		log.Printf("Unhandled message type: %v", msgType.MessageType)
		return
	}

	if _, err := conn.WriteTo(reply.ToBytes(), peer); err != nil {
		log.Printf("Cannot reply to client: %v", err)
	}
}

// setUpClientAndServer creates a Server bound to a random port on 127.0.0.1,
// runs it in the background, and returns a Client whose RemoteAddr points at
// that server. The caller needs to call Server.Close() once finished.
//
// Instead of waiting forever, it panics in two cases:
//   - ActivateAndServe returns before the server is listening, for example
//     because the randomly chosen port is already in use.
//   - The server does not start listening within startTimeout.
func setUpClientAndServer(handler Handler) (*Client, *Server) {
	const startTimeout = 5 * time.Second

	// strong assumption, I know
	loAddr := net.ParseIP("127.0.0.1")
	laddr := net.UDPAddr{
		IP:   loAddr,
		Port: randPort(),
	}
	s := NewServer(laddr, handler)
	// Buffered so the goroutine can always deliver ActivateAndServe's result
	// and exit, even after this function has stopped receiving.
	errCh := make(chan error, 1)
	go func() {
		errCh <- s.ActivateAndServe()
	}()

	c := NewClient()
	// FIXME this doesn't deal well with raw sockets, the actual 0 will be used
	// in the UDP header as source port
	c.LocalAddr = &net.UDPAddr{IP: loAddr, Port: randPort()}

	deadline := time.Now().Add(startTimeout)
	for s.LocalAddr() == nil {
		select {
		case err := <-errCh:
			// The server gave up before it ever started listening.
			log.Panicf("Server failed to start on %s: %v", laddr.String(), err)
		default:
		}
		if time.Now().After(deadline) {
			log.Panicf("Server did not start within %v", startTimeout)
		}
		time.Sleep(10 * time.Millisecond)
		log.Printf("Waiting for server to run...")
	}
	c.RemoteAddr = s.LocalAddr()
	log.Printf("Client.RemoteAddr: %s", c.RemoteAddr)

	return c, s
}

// getLoopbackInterface returns the name of the first network interface that
// either has the loopback flag set or whose name starts with "lo". It returns
// an error if the interfaces cannot be listed or none of them matches.
// TODO this is copied from dhcpv6/server_test.go , we should refactor common code in a separate package
func getLoopbackInterface() (string, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return "", err
	}
	for _, iface := range ifaces {
		// strings.HasPrefix rather than iface.Name[:2]: the slice expression
		// panics for interface names shorter than two bytes.
		if iface.Flags&net.FlagLoopback != 0 || strings.HasPrefix(iface.Name, "lo") {
			return iface.Name, nil
		}
	}
	return "", errors.New("No loopback interface found")
}

// TestNewServer checks that NewServer stores the given local address and
// handler without opening a connection yet (s.conn stays nil).
func TestNewServer(t *testing.T) {
	laddr := net.UDPAddr{
		IP:   net.ParseIP("127.0.0.1"),
		Port: 0,
	}
	s := NewServer(laddr, DORAHandler)
	// Assert before deferring Close, so that a nil server fails this check
	// cleanly instead of panicking in the deferred call.
	require.NotNil(t, s)
	defer s.Close()

	require.Nil(t, s.conn)
	require.Equal(t, laddr, s.localAddr)
	require.NotNil(t, s.Handler)
}

// TestServerActivateAndServe runs a client exchange over the loopback
// interface against a server using DORAHandler. It checks that the
// conversation has four packets and that each one carries the transaction ID
// and client hardware address the client was configured with.
func TestServerActivateAndServe(t *testing.T) {
	c, s := setUpClientAndServer(DORAHandler)
	defer s.Close()

	lo, err := getLoopbackInterface()
	require.NoError(t, err)

	xid := uint32(0xaabbccdd)
	hwaddr := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}

	modifiers := []Modifier{
		WithTransactionID(xid),
		WithHwAddr(hwaddr[:]),
	}

	conv, err := c.Exchange(lo, nil, modifiers...)
	require.NoError(t, err)
	// Presumably one packet per DORA step: Discover, Offer, Request, Ack.
	require.Len(t, conv, 4)
	for _, p := range conv {
		require.Equal(t, xid, p.TransactionID())
		require.Equal(t, hwaddr, p.ClientHwAddr())
	}
}