summaryrefslogtreecommitdiffhomepage
path: root/pkg/dhcp/client.go
blob: 37deb69fff91f3666cc6d6a05c4c978092195e9e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dhcp

import (
	"bytes"
	"context"
	"crypto/rand"
	"fmt"
	"log"
	"sync"
	"time"

	"gvisor.googlesource.com/gvisor/pkg/tcpip"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
	"gvisor.googlesource.com/gvisor/pkg/waiter"
)

// Client is a DHCP client.
//
// A Client is bound to one NIC of a network stack. The fields above mu are
// set once by NewClient; the fields below mu are accessed with mu held by
// the accessor methods.
type Client struct {
	// stack is the network stack used to create UDP endpoints and to add
	// or remove the leased address.
	stack    *stack.Stack
	// nicid identifies the NIC on which addresses are added and removed.
	nicid    tcpip.NICID
	// linkAddr is copied into the client hardware address field of
	// outgoing DHCP messages.
	linkAddr tcpip.LinkAddress

	// mu guards the fields below.
	mu          sync.Mutex
	// addr is the address recorded by the last acknowledged Request; it is
	// empty until then.
	addr        tcpip.Address
	// cfg is the configuration decoded from the last acknowledged Request.
	cfg         Config
	// lease is not read or written by the code visible here.
	// NOTE(review): presumably the lease duration — confirm or remove.
	lease       time.Duration
	// cancelRenew cancels the currently scheduled renewal, if any.
	cancelRenew func()
}

// NewClient returns a DHCP client that operates on NIC nicid of stack s
// and identifies itself to servers with the hardware address linkAddr.
//
// The returned client holds no address and does nothing until Start or
// Request is called.
//
// TODO: add s.LinkAddr(nicid) to *stack.Stack.
func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client {
	c := new(Client)
	c.stack = s
	c.nicid = nicid
	c.linkAddr = linkAddr
	return c
}

// Start starts the DHCP client.
//
// It launches a goroutine that repeatedly calls Request, giving each attempt
// a three second window, until one succeeds. Failed attempts are logged. An
// attempt that fails before its window has elapsed (for example because an
// endpoint could not be created) is not retried until the window is over, so
// a persistent local failure cannot turn into a busy loop.
//
// Start returns immediately; the goroutine exits after the first successful
// Request.
func (c *Client) Start() {
	const attemptWindow = 3 * time.Second

	go func() {
		for {
			log.Print("DHCP request")
			ctx, cancel := context.WithTimeout(context.Background(), attemptWindow)
			err := c.Request(ctx, "")
			if err == nil {
				cancel()
				break
			}
			log.Printf("DHCP request failed: %v", err)
			// If Request gave up because the window expired this returns
			// at once; otherwise it paces the retry to one per window.
			<-ctx.Done()
			cancel()
		}
		log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength)
	}()
}

// Address returns the IP address most recently recorded by a successful
// Request, read under the client's mutex. It is the empty address if no
// request has been acknowledged yet.
func (c *Client) Address() tcpip.Address {
	c.mu.Lock()
	current := c.addr
	c.mu.Unlock()
	return current
}

// Config returns a copy of the DHCP configuration recorded together with
// the current address lease, read under the client's mutex.
func (c *Client) Config() Config {
	c.mu.Lock()
	current := c.cfg
	c.mu.Unlock()
	return current
}

// Shutdown gives up the client's lease and stops renewal.
//
// If an address has been recorded it is removed from the NIC, and if a
// renewal is scheduled its context is cancelled. Both steps run with the
// client's mutex held.
func (c *Client) Shutdown() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if leased := c.addr; leased != "" {
		c.stack.RemoveAddress(c.nicid, leased)
	}
	if stop := c.cancelRenew; stop != nil {
		stop()
	}
}

// Request executes a DHCP request session.
//
// It broadcasts a DHCPDISCOVER (asking for requestedAddr when that is
// non-empty), waits for a reply carrying the same transaction ID, broadcasts
// a DHCPREQUEST for the address in that reply and waits for the answer.
//
// On success, it adds a new address to this client's TCPIP stack and records
// the address and the decoded configuration on the client. If the server
// sets a lease limit a timer is set to automatically renew it.
//
// A non-nil error is returned when an endpoint cannot be created or bound,
// the transaction ID cannot be generated, a packet cannot be written or
// read, a reply's options cannot be parsed, the final reply is not a
// DHCPACK, or ctx is done while waiting for a reply. If the session fails
// after the offered address was added to the stack, the address is removed
// again before returning.
func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error {
	var wq waiter.Queue

	// ep sends from the unspecified address; epin receives broadcasts.
	ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
	if err != nil {
		return fmt.Errorf("dhcp: outbound endpoint: %v", err)
	}
	defer ep.Close()
	if err := ep.Bind(tcpip.FullAddress{
		Addr: "\x00\x00\x00\x00",
		Port: clientPort,
	}, nil); err != nil {
		return fmt.Errorf("dhcp: connect failed: %v", err)
	}

	epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
	if err != nil {
		return fmt.Errorf("dhcp: inbound endpoint: %v", err)
	}
	defer epin.Close()
	if err := epin.Bind(tcpip.FullAddress{
		Addr: "\xff\xff\xff\xff",
		Port: clientPort,
	}, nil); err != nil {
		return fmt.Errorf("dhcp: connect failed: %v", err)
	}

	// The transaction ID ties replies to this session.
	var xid [4]byte
	if _, err := rand.Read(xid[:]); err != nil {
		return fmt.Errorf("dhcp: generating transaction id: %v", err)
	}

	// DHCPDISCOVER
	discover := options{
		{optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
		{optParamReq, []byte{
			1,  // request subnet mask
			3,  // request router
			15, // domain name
			6,  // domain name server
		}},
	}
	if requestedAddr != "" {
		discover = append(discover, option{optReqIPAddr, []byte(requestedAddr)})
	}
	h := make(header, headerBaseSize+discover.len())
	h.init()
	h.setOp(opRequest)
	copy(h.xidbytes(), xid[:])
	h.setBroadcast()
	copy(h.chaddr(), c.linkAddr)
	h.setOptions(discover)

	serverAddr := &tcpip.FullAddress{
		Addr: "\xff\xff\xff\xff",
		Port: serverPort,
	}
	wopts := tcpip.WriteOptions{
		To: serverAddr,
	}
	if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
		return fmt.Errorf("dhcp discovery write: %v", err)
	}

	we, ch := waiter.NewChannelEntry(nil)
	wq.EventRegister(&we, waiter.EventIn)
	defer wq.EventUnregister(&we)

	// readReply returns the next valid reply on epin whose transaction ID
	// matches xid, skipping any other packet. It fails when ctx is done
	// while no packet is pending, or when the read fails with anything
	// other than ErrWouldBlock. stage names the awaited message in errors.
	readReply := func(stage string) (header, error) {
		for {
			var from tcpip.FullAddress
			v, _, err := epin.Read(&from)
			if err == tcpip.ErrWouldBlock {
				select {
				case <-ch:
					continue
				case <-ctx.Done():
					return nil, fmt.Errorf("reading dhcp %s: %v", stage, tcpip.ErrAborted)
				}
			}
			if err != nil {
				// A hard read error will not clear by retrying; looping
				// here would spin without ever consulting ctx.
				return nil, fmt.Errorf("reading dhcp %s: %v", stage, err)
			}
			reply := header(v)
			if reply.isValid() && reply.op() == opReply && bytes.Equal(reply.xidbytes(), xid[:]) {
				return reply, nil
			}
		}
	}

	// DHCPOFFER
	h, e := readReply("offer")
	if e != nil {
		return e
	}
	if _, err := h.options(); err != nil {
		return fmt.Errorf("dhcp offer: %v", err)
	}

	var ack bool
	var cfg Config

	// DHCPREQUEST
	addr := tcpip.Address(h.yiaddr())
	if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
		// The address may already be present on the NIC; that is fine.
		if err != tcpip.ErrDuplicateAddress {
			return fmt.Errorf("adding address: %v", err)
		}
	}
	// Commit the lease on acknowledgement, otherwise undo the AddAddress.
	defer func() {
		if ack {
			c.mu.Lock()
			c.addr = addr
			c.cfg = cfg
			c.mu.Unlock()
		} else {
			c.stack.RemoveAddress(c.nicid, addr)
		}
	}()
	// Reuse the offer packet as the request: flip the op, clear yiaddr and
	// replace the options.
	h.setOp(opRequest)
	for i, b := 0, h.yiaddr(); i < len(b); i++ {
		b[i] = 0
	}
	h.setOptions([]option{
		{optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
		{optReqIPAddr, []byte(addr)},
		{optDHCPServer, h.siaddr()},
	})
	if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
		return fmt.Errorf("dhcp request write: %v", err)
	}

	// DHCPACK
	h, e = readReply("ack")
	if e != nil {
		return e
	}
	opts, e := h.options()
	if e != nil {
		return fmt.Errorf("dhcp ack: %v", e)
	}
	if err := cfg.decode(opts); err != nil {
		return fmt.Errorf("dhcp ack bad options: %v", err)
	}
	msgtype, e := opts.dhcpMsgType()
	if e != nil {
		return fmt.Errorf("dhcp ack: %v", e)
	}
	ack = msgtype == dhcpACK
	if !ack {
		return fmt.Errorf("dhcp: request not acknowledged")
	}
	if cfg.LeaseLength != 0 {
		go c.renewAfter(cfg.LeaseLength)
	}
	return nil
}

// renewAfter schedules a lease renewal to run once d has elapsed.
//
// Any previously scheduled renewal is cancelled first, and the new
// renewal's cancel function is stored in c.cancelRenew so that Shutdown or
// a later renewAfter can cancel it. When the timer fires, Request is called
// for the client's current address; if that fails the failure is logged and
// another attempt is scheduled one minute later.
func (c *Client) renewAfter(d time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.cancelRenew != nil {
		c.cancelRenew()
	}
	ctx, cancel := context.WithCancel(context.Background())
	c.cancelRenew = cancel
	go func() {
		timer := time.NewTimer(d)
		defer timer.Stop()
		select {
		case <-ctx.Done():
		case <-timer.C:
			// Read the address through Address so the access happens
			// with c.mu held; this goroutine does not hold the lock.
			if err := c.Request(ctx, c.Address()); err != nil {
				log.Printf("address renewal failed: %v", err)
				go c.renewAfter(1 * time.Minute)
			}
		}
	}()
}