diff options
author | Sean Karlage <skarlage@fb.com> | 2018-08-15 10:34:04 -0700 |
---|---|---|
committer | Sean Karlage <skarlage@fb.com> | 2018-08-15 10:34:04 -0700 |
commit | 40169a2169f788c41cb806c9d344148e72a3a0bd (patch) | |
tree | 6cfee5029c946710e4ed7789151ac02e62fa852b /dhcpv6 | |
parent | 8ea2525c898436a2a935580de67727bbe7035c85 (diff) | |
parent | 926a42d133247d7a4fa388548e4323b77421f798 (diff) |
Merge branch 'master' into dhcpv4-moar-tests
Diffstat (limited to 'dhcpv6')
-rw-r--r-- | dhcpv6/async/client.go (renamed from dhcpv6/async.go) | 101 | ||||
-rw-r--r-- | dhcpv6/async/client_test.go | 153 | ||||
-rw-r--r-- | dhcpv6/async_test.go | 54 | ||||
-rw-r--r-- | dhcpv6/client.go | 4 | ||||
-rw-r--r-- | dhcpv6/dhcpv6.go | 58 | ||||
-rw-r--r-- | dhcpv6/dhcpv6_test.go | 47 | ||||
-rw-r--r-- | dhcpv6/dhcpv6message.go | 27 | ||||
-rw-r--r-- | dhcpv6/dhcpv6message_test.go | 34 | ||||
-rw-r--r-- | dhcpv6/dhcpv6relay.go | 25 | ||||
-rw-r--r-- | dhcpv6/dhcpv6relay_test.go | 22 | ||||
-rw-r--r-- | dhcpv6/future.go | 111 | ||||
-rw-r--r-- | dhcpv6/future_test.go | 170 | ||||
-rw-r--r-- | dhcpv6/modifiers.go | 6 | ||||
-rw-r--r-- | dhcpv6/option_archtype.go | 62 | ||||
-rw-r--r-- | dhcpv6/option_archtype_test.go | 5 | ||||
-rw-r--r-- | dhcpv6/utils.go | 77 | ||||
-rw-r--r-- | dhcpv6/utils_test.go | 69 |
17 files changed, 425 insertions, 600 deletions
diff --git a/dhcpv6/async.go b/dhcpv6/async/client.go index e9930bf..08c2cfb 100644 --- a/dhcpv6/async.go +++ b/dhcpv6/async/client.go @@ -1,4 +1,4 @@ -package dhcpv6 +package async import ( "context" @@ -6,10 +6,13 @@ import ( "net" "sync" "time" + + "github.com/fanliao/go-promise" + "github.com/insomniacslk/dhcp/dhcpv6" ) -// AsyncClient implements an asynchronous DHCPv6 client -type AsyncClient struct { +// Client implements an asynchronous DHCPv6 client +type Client struct { ReadTimeout time.Duration WriteTimeout time.Duration LocalAddr net.Addr @@ -19,35 +22,35 @@ type AsyncClient struct { connection *net.UDPConn cancel context.CancelFunc stopping *sync.WaitGroup - receiveQueue chan DHCPv6 - sendQueue chan DHCPv6 + receiveQueue chan dhcpv6.DHCPv6 + sendQueue chan dhcpv6.DHCPv6 packetsLock sync.Mutex - packets map[uint32](chan Response) + packets map[uint32]*promise.Promise errors chan error } -// NewAsyncClient creates an asynchronous client -func NewAsyncClient() *AsyncClient { - return &AsyncClient{ - ReadTimeout: DefaultReadTimeout, - WriteTimeout: DefaultWriteTimeout, +// NewClient creates an asynchronous client +func NewClient() *Client { + return &Client{ + ReadTimeout: dhcpv6.DefaultReadTimeout, + WriteTimeout: dhcpv6.DefaultWriteTimeout, } } // OpenForInterface starts the client on the specified interface, replacing // client LocalAddr with a link-local address of the given interface and // standard DHCP port (546). -func (c *AsyncClient) OpenForInterface(ifname string, bufferSize int) error { - addr, err := GetLinkLocalAddr(ifname) +func (c *Client) OpenForInterface(ifname string, bufferSize int) error { + addr, err := dhcpv6.GetLinkLocalAddr(ifname) if err != nil { return err } - c.LocalAddr = &net.UDPAddr{IP: *addr, Port: DefaultClientPort, Zone: ifname} + c.LocalAddr = &net.UDPAddr{IP: *addr, Port: dhcpv6.DefaultClientPort, Zone: ifname} return c.Open(bufferSize) } // Open starts the client -func (c *AsyncClient) Open(bufferSize int) error { +func (c *Client) Open(bufferSize int) error { var ( addr *net.UDPAddr ok bool @@ -64,9 +67,9 @@ func (c *AsyncClient) Open(bufferSize int) error { return err } c.stopping = new(sync.WaitGroup) - c.sendQueue = make(chan DHCPv6, bufferSize) - c.receiveQueue = make(chan DHCPv6, bufferSize) - c.packets = make(map[uint32](chan Response)) + c.sendQueue = make(chan dhcpv6.DHCPv6, bufferSize) + c.receiveQueue = make(chan dhcpv6.DHCPv6, bufferSize) + c.packets = make(map[uint32]*promise.Promise) c.packetsLock = sync.Mutex{} c.errors = make(chan error) @@ -79,7 +82,7 @@ func (c *AsyncClient) Open(bufferSize int) error { } // Close stops the client -func (c *AsyncClient) Close() { +func (c *Client) Close() { // Wait for sender and receiver loops c.stopping.Add(2) c.cancel() @@ -93,17 +96,17 @@ func (c *AsyncClient) Close() { } // Errors returns a channel where runtime errors are posted -func (c *AsyncClient) Errors() <-chan error { +func (c *Client) Errors() <-chan error { return c.errors } -func (c *AsyncClient) addError(err error) { +func (c *Client) addError(err error) { if !c.IgnoreErrors { c.errors <- err } } -func (c *AsyncClient) receiverLoop(ctx context.Context) { +func (c *Client) receiverLoop(ctx context.Context) { defer func() { c.stopping.Done() }() for { select { @@ -115,7 +118,7 @@ func (c *AsyncClient) receiverLoop(ctx context.Context) { } } -func (c *AsyncClient) senderLoop(ctx context.Context) { +func (c *Client) senderLoop(ctx context.Context) { defer func() { c.stopping.Done() }() for { select { @@ -127,41 +130,41 @@ func (c *AsyncClient) senderLoop(ctx context.Context) { } } -func (c *AsyncClient) send(packet DHCPv6) { - transactionID, err := GetTransactionID(packet) +func (c *Client) send(packet dhcpv6.DHCPv6) { + transactionID, err := dhcpv6.GetTransactionID(packet) if err != nil { c.addError(fmt.Errorf("Warning: This should never happen, there is no transaction ID on %s", packet)) return } c.packetsLock.Lock() - f := c.packets[transactionID] + p := c.packets[transactionID] c.packetsLock.Unlock() raddr, err := c.remoteAddr() if err != nil { - f <- NewResponse(nil, err) + p.Reject(err) return } c.connection.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) _, err = c.connection.WriteTo(packet.ToBytes(), raddr) if err != nil { - f <- NewResponse(nil, err) + p.Reject(err) return } c.receiveQueue <- packet } -func (c *AsyncClient) receive(_ DHCPv6) { +func (c *Client) receive(_ dhcpv6.DHCPv6) { var ( oobdata = []byte{} - received DHCPv6 + received dhcpv6.DHCPv6 ) c.connection.SetReadDeadline(time.Now().Add(c.ReadTimeout)) for { - buffer := make([]byte, maxUDPReceivedPacketSize) + buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize) n, _, _, _, err := c.connection.ReadMsgUDP(buffer, oobdata) if err != nil { if err, ok := err.(net.Error); !ok || !err.Timeout() { @@ -169,7 +172,7 @@ func (c *AsyncClient) receive(_ DHCPv6) { } return } - received, err = FromBytes(buffer[:n]) + received, err = dhcpv6.FromBytes(buffer[:n]) if err != nil { // skip non-DHCP packets continue @@ -177,23 +180,23 @@ func (c *AsyncClient) receive(_ DHCPv6) { break } - transactionID, err := GetTransactionID(received) + transactionID, err := dhcpv6.GetTransactionID(received) if err != nil { c.addError(fmt.Errorf("Unable to get a transactionID for %s: %s", received, err)) return } c.packetsLock.Lock() - if f, ok := c.packets[transactionID]; ok { + if p, ok := c.packets[transactionID]; ok { delete(c.packets, transactionID) - f <- NewResponse(received, nil) + p.Resolve(received) } c.packetsLock.Unlock() } -func (c *AsyncClient) remoteAddr() (*net.UDPAddr, error) { +func (c *Client) remoteAddr() (*net.UDPAddr, error) { if c.RemoteAddr == nil { - return &net.UDPAddr{IP: AllDHCPRelayAgentsAndServers, Port: DefaultServerPort}, nil + return &net.UDPAddr{IP: dhcpv6.AllDHCPRelayAgentsAndServers, Port: dhcpv6.DefaultServerPort}, nil } if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok { @@ -204,32 +207,20 @@ func (c *AsyncClient) remoteAddr() (*net.UDPAddr, error) { // Send inserts a message to the queue to be sent asynchronously. // Returns a future which resolves to response and error. -func (c *AsyncClient) Send(message DHCPv6, modifiers ...Modifier) Future { +func (c *Client) Send(message dhcpv6.DHCPv6, modifiers ...dhcpv6.Modifier) *promise.Future { for _, mod := range modifiers { message = mod(message) } - transactionID, err := GetTransactionID(message) + transactionID, err := dhcpv6.GetTransactionID(message) if err != nil { - return NewFailureFuture(err) + return promise.Wrap(err) } - f := NewFuture() + p := promise.NewPromise() c.packetsLock.Lock() - c.packets[transactionID] = f + c.packets[transactionID] = p c.packetsLock.Unlock() c.sendQueue <- message - return f -} - -// Exchange executes asynchronously a 4-way DHCPv6 request (SOLICIT, -// ADVERTISE, REQUEST, REPLY). -func (c *AsyncClient) Exchange(solicit DHCPv6, modifiers ...Modifier) Future { - return c.Send(solicit).OnSuccess(func(advertise DHCPv6) Future { - request, err := NewRequestFromAdvertise(advertise) - if err != nil { - return NewFailureFuture(err) - } - return c.Send(request, modifiers...) - }) + return p.Future } diff --git a/dhcpv6/async/client_test.go b/dhcpv6/async/client_test.go new file mode 100644 index 0000000..0bc3a87 --- /dev/null +++ b/dhcpv6/async/client_test.go @@ -0,0 +1,153 @@ +package async + +import ( + "context" + "net" + "testing" + "time" + + "github.com/insomniacslk/dhcp/dhcpv6" + "github.com/insomniacslk/dhcp/iana" + "github.com/stretchr/testify/require" +) + +const retries = 5 + +// solicit creates new solicit based on the mac address +func solicit(input string) (dhcpv6.DHCPv6, error) { + mac, err := net.ParseMAC(input) + if err != nil { + return nil, err + } + duid := dhcpv6.Duid{ + Type: dhcpv6.DUID_LLT, + HwType: iana.HwTypeEthernet, + Time: dhcpv6.GetTime(), + LinkLayerAddr: mac, + } + return dhcpv6.NewSolicitWithCID(duid) +} + +// server creates a server which responds with a predefined response +func serve(ctx context.Context, addr *net.UDPAddr, response dhcpv6.DHCPv6) error { + conn, err := net.ListenUDP("udp6", addr) + if err != nil { + return err + } + go func() { + defer conn.Close() + oobdata := []byte{} + buffer := make([]byte, dhcpv6.MaxUDPReceivedPacketSize) + for { + select { + case <-ctx.Done(): + return + default: + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) + if err != nil { + continue + } + _, err = dhcpv6.FromBytes(buffer[:n]) + if err != nil { + continue + } + conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err = conn.WriteTo(response.ToBytes(), src) + if err != nil { + continue + } + } + } + }() + return nil +} + +func TestNewClient(t *testing.T) { + c := NewClient() + require.NotNil(t, c) + require.Equal(t, c.ReadTimeout, dhcpv6.DefaultReadTimeout) + require.Equal(t, c.ReadTimeout, dhcpv6.DefaultWriteTimeout) +} + +func TestOpenInvalidAddrFailes(t *testing.T) { + c := NewClient() + err := c.Open(512) + require.Error(t, err) +} + +// This test uses port 15438 so please make sure its not used before running +func TestOpenClose(t *testing.T) { + c := NewClient() + addr, err := net.ResolveUDPAddr("udp6", ":15438") + require.NoError(t, err) + c.LocalAddr = addr + err = c.Open(512) + require.NoError(t, err) + defer c.Close() +} + +// This test uses ports 15438 and 15439 so please make sure they are not used +// before running +func TestSendTimeout(t *testing.T) { + c := NewClient() + addr, err := net.ResolveUDPAddr("udp6", ":15438") + require.NoError(t, err) + remote, err := net.ResolveUDPAddr("udp6", ":15439") + require.NoError(t, err) + c.ReadTimeout = 50 * time.Millisecond + c.WriteTimeout = 50 * time.Millisecond + c.LocalAddr = addr + c.RemoteAddr = remote + err = c.Open(512) + require.NoError(t, err) + defer c.Close() + m, err := dhcpv6.NewMessage() + require.NoError(t, err) + _, err, timeout := c.Send(m).GetOrTimeout(200) + require.NoError(t, err) + require.True(t, timeout) +} + +// This test uses ports 15438 and 15439 so please make sure they are not used +// before running +func TestSend(t *testing.T) { + s, err := solicit("c8:6c:2c:47:96:fd") + require.NoError(t, err) + require.NotNil(t, s) + + a, err := dhcpv6.NewAdvertiseFromSolicit(s) + require.NoError(t, err) + require.NotNil(t, a) + + c := NewClient() + addr, err := net.ResolveUDPAddr("udp6", ":15438") + require.NoError(t, err) + remote, err := net.ResolveUDPAddr("udp6", ":15439") + require.NoError(t, err) + c.LocalAddr = addr + c.RemoteAddr = remote + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err = serve(ctx, remote, a) + require.NoError(t, err) + + err = c.Open(16) + require.NoError(t, err) + defer c.Close() + + f := c.Send(s) + + var passed bool + for i := 0; i < retries; i++ { + response, err, timeout := f.GetOrTimeout(1000) + if timeout { + continue + } + passed = true + require.NoError(t, err) + require.Equal(t, a, response) + } + require.True(t, passed, "All attempts to TestSend timed out") +} diff --git a/dhcpv6/async_test.go b/dhcpv6/async_test.go deleted file mode 100644 index 4f5a750..0000000 --- a/dhcpv6/async_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package dhcpv6 - -import ( - "net" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestNewAsyncClient(t *testing.T) { - c := NewAsyncClient() - require.NotNil(t, c) - require.Equal(t, c.ReadTimeout, DefaultReadTimeout) - require.Equal(t, c.ReadTimeout, DefaultWriteTimeout) -} - -func TestOpenInvalidAddrFailes(t *testing.T) { - c := NewAsyncClient() - err := c.Open(512) - require.Error(t, err) -} - -// This test uses port 15438 so please make sure its not used before running -func TestOpenClose(t *testing.T) { - c := NewAsyncClient() - addr, err := net.ResolveUDPAddr("udp6", ":15438") - require.NoError(t, err) - c.LocalAddr = addr - err = c.Open(512) - require.NoError(t, err) - defer c.Close() -} - -// This test uses ports 15438 and 15439 so please make sure they are not used -// before running -func TestSendTimeout(t *testing.T) { - c := NewAsyncClient() - addr, err := net.ResolveUDPAddr("udp6", ":15438") - require.NoError(t, err) - remote, err := net.ResolveUDPAddr("udp6", ":15439") - require.NoError(t, err) - c.ReadTimeout = 50 * time.Millisecond - c.WriteTimeout = 50 * time.Millisecond - c.LocalAddr = addr - c.RemoteAddr = remote - err = c.Open(512) - require.NoError(t, err) - defer c.Close() - m, err := NewMessage() - require.NoError(t, err) - _, err = c.Send(m).WaitTimeout(200 * time.Millisecond) - require.Error(t, err) -} diff --git a/dhcpv6/client.go b/dhcpv6/client.go index 3492a47..3ed7861 100644 --- a/dhcpv6/client.go +++ b/dhcpv6/client.go @@ -11,7 +11,7 @@ const ( DefaultWriteTimeout = 3 * time.Second // time to wait for write calls DefaultReadTimeout = 3 * time.Second // time to wait for read calls DefaultInterfaceUpTimeout = 3 * time.Second // time to wait before a network interface goes up - maxUDPReceivedPacketSize = 8192 // arbitrary size. Theoretically could be up to 65kb + MaxUDPReceivedPacketSize = 8192 // arbitrary size. Theoretically could be up to 65kb ) // Broadcast destination IP addresses as defined by RFC 3315 @@ -148,7 +148,7 @@ func (c *Client) sendReceive(ifname string, packet DHCPv6, expectedType MessageT isMessage = true } for { - buf := make([]byte, maxUDPReceivedPacketSize) + buf := make([]byte, MaxUDPReceivedPacketSize) n, _, _, _, err := conn.ReadMsgUDP(buf, oobdata) if err != nil { return nil, err diff --git a/dhcpv6/dhcpv6.go b/dhcpv6/dhcpv6.go index 0dabca5..9382334 100644 --- a/dhcpv6/dhcpv6.go +++ b/dhcpv6/dhcpv6.go @@ -1,8 +1,12 @@ package dhcpv6 import ( + "errors" "fmt" "net" + "strings" + + "github.com/insomniacslk/dhcp/iana" ) type DHCPv6 interface { @@ -199,3 +203,57 @@ func EncapsulateRelay(d DHCPv6, mType MessageType, linkAddr, peerAddr net.IP) (D outer.AddOption(&orm) return &outer, nil } + +// IsUsingUEFI function takes a DHCPv6 message and returns true if +// the machine trying to netboot is using UEFI of false if it is not. +func IsUsingUEFI(msg DHCPv6) bool { + // RFC 4578 says: + // As of the writing of this document, the following pre-boot + // architecture types have been requested. + // Type Architecture Name + // ---- ----------------- + // 0 Intel x86PC + // 1 NEC/PC98 + // 2 EFI Itanium + // 3 DEC Alpha + // 4 Arc x86 + // 5 Intel Lean Client + // 6 EFI IA32 + // 7 EFI BC + // 8 EFI Xscale + // 9 EFI x86-64 + if opt := msg.GetOneOption(OptionClientArchType); opt != nil { + optat := opt.(*OptClientArchType) + for _, at := range optat.ArchTypes { + // TODO investigate if other types are appropriate + if at == iana.EFI_BC || at == iana.EFI_X86_64 { + return true + } + } + } + if opt := msg.GetOneOption(OptionUserClass); opt != nil { + optuc := opt.(*OptUserClass) + for _, uc := range optuc.UserClasses { + if strings.Contains(string(uc), "EFI") { + return true + } + } + } + return false +} + +// GetTransactionID returns a transactionID of a message or its inner message +// in case of relay +func GetTransactionID(packet DHCPv6) (uint32, error) { + if message, ok := packet.(*DHCPv6Message); ok { + return message.TransactionID(), nil + } + if relay, ok := packet.(*DHCPv6Relay); ok { + message, err := relay.GetInnerMessage() + if err != nil { + return 0, err + } + return GetTransactionID(message) + } + return 0, errors.New("Invalid DHCPv6 packet") +} diff --git a/dhcpv6/dhcpv6_test.go b/dhcpv6/dhcpv6_test.go index 4a44e0e..6840252 100644 --- a/dhcpv6/dhcpv6_test.go +++ b/dhcpv6/dhcpv6_test.go @@ -215,5 +215,52 @@ func TestNewMessageTypeSolicitWithCID(t *testing.T) { require.Equal(t, len(opts), 2) } + +func TestIsUsingUEFIArchTypeTrue(t *testing.T) { + msg := DHCPv6Message{} + opt := OptClientArchType{ArchTypes: []iana.ArchType{iana.EFI_BC}} + msg.AddOption(&opt) + require.True(t, IsUsingUEFI(&msg)) +} + +func TestIsUsingUEFIArchTypeFalse(t *testing.T) { + msg := DHCPv6Message{} + opt := OptClientArchType{ArchTypes: []iana.ArchType{iana.INTEL_X86PC}} + msg.AddOption(&opt) + require.False(t, IsUsingUEFI(&msg)) +} + +func TestIsUsingUEFIUserClassTrue(t *testing.T) { + msg := DHCPv6Message{} + opt := OptUserClass{UserClasses: [][]byte{[]byte("ipxeUEFI")}} + msg.AddOption(&opt) + require.True(t, IsUsingUEFI(&msg)) +} + +func TestIsUsingUEFIUserClassFalse(t *testing.T) { + msg := DHCPv6Message{} + opt := OptUserClass{UserClasses: [][]byte{[]byte("ipxeLegacy")}} + msg.AddOption(&opt) + require.False(t, IsUsingUEFI(&msg)) +} + +func TestGetTransactionIDMessage(t *testing.T) { + message, err := NewMessage() + require.NoError(t, err) + transactionID, err := GetTransactionID(message) + require.NoError(t, err) + require.Equal(t, transactionID, message.(*DHCPv6Message).TransactionID()) +} + +func TestGetTransactionIDRelay(t *testing.T) { + message, err := NewMessage() + require.NoError(t, err) + relay, err := EncapsulateRelay(message, MessageTypeRelayForward, nil, nil) + require.NoError(t, err) + transactionID, err := GetTransactionID(relay) + require.NoError(t, err) + require.Equal(t, transactionID, message.(*DHCPv6Message).TransactionID()) +} + // TODO test NewMessageTypeSolicit // test String and Summary diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go index e601932..7ee00ad 100644 --- a/dhcpv6/dhcpv6message.go +++ b/dhcpv6/dhcpv6message.go @@ -286,6 +286,33 @@ func (d *DHCPv6Message) UpdateOption(option Option) { d.AddOption(option) } +// IsNetboot returns true if the machine is trying to netboot. It checks if +// "boot file" is one of the requested options, which is useful for +// SOLICIT/REQUEST packet types, it also checks if the "boot file" option is +// included in the packet, which is useful for ADVERTISE/REPLY packet. +func (d *DHCPv6Message) IsNetboot() bool { + if d.IsOptionRequested(OptionBootfileURL) { + return true + } + if optbf := d.GetOneOption(OptionBootfileURL); optbf != nil { + return true + } + return false +} + +// IsOptionRequested takes an OptionCode and returns true if that option is +// within the requested options of the DHCPv6 message. +func (d *DHCPv6Message) IsOptionRequested(requested OptionCode) bool { + for _, optoro := range d.GetOption(OptionORO) { + for _, o := range optoro.(*OptRequestedOption).RequestedOptions() { + if o == requested { + return true + } + } + } + return false +} + func (d *DHCPv6Message) String() string { return fmt.Sprintf("DHCPv6Message(messageType=%v transactionID=0x%06x, %d options)", d.MessageTypeToString(), d.TransactionID(), len(d.options), diff --git a/dhcpv6/dhcpv6message_test.go b/dhcpv6/dhcpv6message_test.go new file mode 100644 index 0000000..5c92a7b --- /dev/null +++ b/dhcpv6/dhcpv6message_test.go @@ -0,0 +1,34 @@ +package dhcpv6 + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsNetboot(t *testing.T) { + msg1 := DHCPv6Message{} + require.False(t, msg1.IsNetboot()) + + msg2 := DHCPv6Message{} + optro := OptRequestedOption{} + optro.AddRequestedOption(OptionBootfileURL) + msg2.AddOption(&optro) + require.True(t, msg2.IsNetboot()) + + msg3 := DHCPv6Message{} + optbf := OptBootFileURL{} + msg3.AddOption(&optbf) + require.True(t, msg3.IsNetboot()) +} + +func TestIsOptionRequested(t *testing.T) { + msg1 := DHCPv6Message{} + require.False(t, msg1.IsOptionRequested(OptionDNSRecursiveNameServer)) + + msg2 := DHCPv6Message{} + optro := OptRequestedOption{} + optro.AddRequestedOption(OptionDNSRecursiveNameServer) + msg2.AddOption(&optro) + require.True(t, msg2.IsOptionRequested(OptionDNSRecursiveNameServer)) +} diff --git a/dhcpv6/dhcpv6relay.go b/dhcpv6/dhcpv6relay.go index d9555cb..5c30bad 100644 --- a/dhcpv6/dhcpv6relay.go +++ b/dhcpv6/dhcpv6relay.go @@ -158,23 +158,26 @@ func (d *DHCPv6Relay) GetInnerMessage() (DHCPv6, error) { } } -// NewRelayReplFromRelayForw creates a RELAY_REPL packet based on a RELAY_FORW -// packet and replaces the inner message with the passed DHCPv6 message. +// NewRelayReplFromRelayForw creates a MessageTypeRelayReply based on a +// MessageTypeRelayForward and replaces the inner message with the passed +// DHCPv6 message. It copies the OptionInterfaceID and OptionRemoteID if the +// options are present in the Relay packet. func NewRelayReplFromRelayForw(relayForw, msg DHCPv6) (DHCPv6, error) { var ( err error linkAddr, peerAddr []net.IP - optiids []Option + optiid []Option + optrid []Option ) if relayForw == nil { - return nil, errors.New("RELAY_FORW cannot be nil") + return nil, errors.New("Relay message cannot be nil") } relay, ok := relayForw.(*DHCPv6Relay) if !ok { return nil, errors.New("Not a DHCPv6Relay") } if relay.Type() != MessageTypeRelayForward { - return nil, errors.New("The passed packet is not of type RELAY_FORW") + return nil, errors.New("The passed packet is not of type MessageTypeRelayForward") } if msg == nil { return nil, errors.New("The passed message cannot be nil") @@ -185,7 +188,8 @@ func NewRelayReplFromRelayForw(relayForw, msg DHCPv6) (DHCPv6, error) { for { linkAddr = append(linkAddr, relay.LinkAddr()) peerAddr = append(peerAddr, relay.PeerAddr()) - optiids = append(optiids, relay.GetOneOption(OptionInterfaceID)) + optiid = append(optiid, relay.GetOneOption(OptionInterfaceID)) + optrid = append(optrid, relay.GetOneOption(OptionRemoteID)) decap, err := DecapsulateRelay(relay) if err != nil { return nil, err @@ -198,12 +202,15 @@ func NewRelayReplFromRelayForw(relayForw, msg DHCPv6) (DHCPv6, error) { } for i := len(linkAddr) - 1; i >= 0; i-- { msg, err = EncapsulateRelay(msg, MessageTypeRelayReply, linkAddr[i], peerAddr[i]) - if opt := optiids[i]; opt != nil { - msg.AddOption(opt) - } if err != nil { return nil, err } + if opt := optiid[i]; opt != nil { + msg.AddOption(opt) + } + if opt := optrid[i]; opt != nil { + msg.AddOption(opt) + } } return msg, nil } diff --git a/dhcpv6/dhcpv6relay_test.go b/dhcpv6/dhcpv6relay_test.go index afb4086..fe1b840 100644 --- a/dhcpv6/dhcpv6relay_test.go +++ b/dhcpv6/dhcpv6relay_test.go @@ -108,19 +108,23 @@ func TestDHCPv6RelayToBytes(t *testing.T) { } func TestNewRelayRepFromRelayForw(t *testing.T) { + // create a new relay forward rf := DHCPv6Relay{} rf.SetMessageType(MessageTypeRelayForward) rf.SetPeerAddr(net.IPv6linklocalallrouters) rf.SetLinkAddr(net.IPv6interfacelocalallnodes) - oro := OptRelayMsg{} - s := DHCPv6Message{} - s.SetMessage(MessageTypeSolicit) - cid := OptClientId{} - s.AddOption(&cid) - oro.SetRelayMessage(&s) - rf.AddOption(&oro) + rf.AddOption(&OptInterfaceId{}) + rf.AddOption(&OptRemoteId{}) - a, err := NewAdvertiseFromSolicit(&s) + // create the inner message + s, err := NewMessage() + require.NoError(t, err) + s.AddOption(&OptClientId{}) + orm := OptRelayMsg{} + orm.SetRelayMessage(s) + rf.AddOption(&orm) + + a, err := NewAdvertiseFromSolicit(s) require.NoError(t, err) rr, err := NewRelayReplFromRelayForw(&rf, a) require.NoError(t, err) @@ -129,6 +133,8 @@ func TestNewRelayRepFromRelayForw(t *testing.T) { require.Equal(t, relay.HopCount(), rf.HopCount()) require.Equal(t, relay.PeerAddr(), rf.PeerAddr()) require.Equal(t, relay.LinkAddr(), rf.LinkAddr()) + require.NotNil(t, rr.GetOneOption(OptionInterfaceID)) + require.NotNil(t, rr.GetOneOption(OptionRemoteID)) m, err := relay.GetInnerMessage() require.NoError(t, err) require.Equal(t, m, a) diff --git a/dhcpv6/future.go b/dhcpv6/future.go deleted file mode 100644 index d0ae6cd..0000000 --- a/dhcpv6/future.go +++ /dev/null @@ -1,111 +0,0 @@ -package dhcpv6 - -import ( - "errors" - "time" -) - -// Response represents a value which Future resolves to -type Response interface { - Value() DHCPv6 - Error() error -} - -// Future is a result of an asynchronous DHCPv6 call -type Future (<-chan Response) - -// SuccessFun can be used as a success callback -type SuccessFun func(val DHCPv6) Future - -// FailureFun can be used as a failure callback -type FailureFun func(err error) Future - -type response struct { - val DHCPv6 - err error -} - -func (r *response) Value() DHCPv6 { - return r.val -} - -func (r *response) Error() error { - return r.err -} - -// NewFuture creates a new future, which can be written to -func NewFuture() chan Response { - return make(chan Response, 1) -} - -// NewResponse creates a new future response -func NewResponse(val DHCPv6, err error) Response { - return &response{val: val, err: err} -} - -// NewSuccessFuture creates a future that resolves to a value -func NewSuccessFuture(val DHCPv6) Future { - f := NewFuture() - go func() { - f <- NewResponse(val, nil) - }() - return f -} - -// NewFailureFuture creates a future that resolves to an error -func NewFailureFuture(err error) Future { - f := NewFuture() - go func() { - f <- NewResponse(nil, err) - }() - return f -} - -// Then allows to chain the futures executing appropriate function depending -// on the previous future value -func (f Future) Then(success SuccessFun, failure FailureFun) Future { - g := NewFuture() - go func() { - r := <-f - if r.Error() != nil { - r = <-failure(r.Error()) - g <- r - } else { - r = <-success(r.Value()) - g <- r - } - }() - return g -} - -// OnSuccess allows to chain the futures executing next one only if the first -// one succeeds -func (f Future) OnSuccess(success SuccessFun) Future { - return f.Then(success, func(err error) Future { - return NewFailureFuture(err) - }) -} - -// OnFailure allows to chain the futures executing next one only if the first -// one fails -func (f Future) OnFailure(failure FailureFun) Future { - return f.Then(func(val DHCPv6) Future { - return NewSuccessFuture(val) - }, failure) -} - -// Wait blocks the execution until a future resolves -func (f Future) Wait() (DHCPv6, error) { - r := <-f - return r.Value(), r.Error() -} - -// WaitTimeout blocks the execution until a future resolves or times out -func (f Future) WaitTimeout(timeout time.Duration) (DHCPv6, error) { - select { - case r := <-f: - return r.Value(), r.Error() - case <-time.After(timeout): - return nil, errors.New("Timed out") - } -} diff --git a/dhcpv6/future_test.go b/dhcpv6/future_test.go deleted file mode 100644 index bee87e3..0000000 --- a/dhcpv6/future_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package dhcpv6 - -import ( - "errors" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestResponseValue(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - r := NewResponse(m, nil) - require.Equal(t, r.Value(), m) - require.Equal(t, r.Error(), nil) -} - -func TestResponseError(t *testing.T) { - e := errors.New("Test error") - r := NewResponse(nil, e) - require.Equal(t, r.Value(), nil) - require.Equal(t, r.Error(), e) -} - -func TestSuccessFuture(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - f := NewSuccessFuture(m) - - val, err := f.Wait() - require.NoError(t, err) - require.Equal(t, val, m) -} - -func TestFailureFuture(t *testing.T) { - e := errors.New("Test error") - f := NewFailureFuture(e) - - val, err := f.Wait() - require.Equal(t, err, e) - require.Equal(t, val, nil) -} - -func TestThenSuccess(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - s, err := NewMessage() - require.NoError(t, err) - e := errors.New("Test error") - - f := NewSuccessFuture(m). - Then(func(_ DHCPv6) Future { - return NewSuccessFuture(s) - }, func(_ error) Future { - return NewFailureFuture(e) - }) - - val, err := f.Wait() - require.NoError(t, err) - require.NotEqual(t, val, m) - require.Equal(t, val, s) -} - -func TestThenFailure(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - s, err := NewMessage() - require.NoError(t, err) - e := errors.New("Test error") - e2 := errors.New("Test error 2") - - f := NewFailureFuture(e). - Then(func(_ DHCPv6) Future { - return NewSuccessFuture(s) - }, func(_ error) Future { - return NewFailureFuture(e2) - }) - - val, err := f.Wait() - require.Error(t, err) - require.NotEqual(t, val, m) - require.NotEqual(t, val, s) - require.NotEqual(t, err, e) - require.Equal(t, err, e2) -} - -func TestOnSuccess(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - s, err := NewMessage() - require.NoError(t, err) - - f := NewSuccessFuture(m). - OnSuccess(func(_ DHCPv6) Future { - return NewSuccessFuture(s) - }) - - val, err := f.Wait() - require.NoError(t, err) - require.NotEqual(t, val, m) - require.Equal(t, val, s) -} - -func TestOnSuccessForFailureFuture(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - e := errors.New("Test error") - - f := NewFailureFuture(e). - OnSuccess(func(_ DHCPv6) Future { - return NewSuccessFuture(m) - }) - - val, err := f.Wait() - require.Error(t, err) - require.Equal(t, err, e) - require.NotEqual(t, val, m) -} - -func TestOnFailure(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - s, err := NewMessage() - require.NoError(t, err) - e := errors.New("Test error") - - f := NewFailureFuture(e). - OnFailure(func(_ error) Future { - return NewSuccessFuture(s) - }) - - val, err := f.Wait() - require.NoError(t, err) - require.NotEqual(t, val, m) - require.Equal(t, val, s) -} - -func TestOnFailureForSuccessFuture(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - s, err := NewMessage() - require.NoError(t, err) - - f := NewSuccessFuture(m). - OnFailure(func(_ error) Future { - return NewSuccessFuture(s) - }) - - val, err := f.Wait() - require.NoError(t, err) - require.NotEqual(t, val, s) - require.Equal(t, val, m) -} - -func TestWaitTimeout(t *testing.T) { - m, err := NewMessage() - require.NoError(t, err) - s, err := NewMessage() - require.NoError(t, err) - f := NewSuccessFuture(m).OnSuccess(func(_ DHCPv6) Future { - time.Sleep(1 * time.Second) - return NewSuccessFuture(s) - }) - val, err := f.WaitTimeout(50 * time.Millisecond) - require.Error(t, err) - require.Equal(t, err.Error(), "Timed out") - require.NotEqual(t, val, m) - require.NotEqual(t, val, s) -} diff --git a/dhcpv6/modifiers.go b/dhcpv6/modifiers.go index 32d11ac..6cd66db 100644 --- a/dhcpv6/modifiers.go +++ b/dhcpv6/modifiers.go @@ -2,6 +2,8 @@ package dhcpv6 import ( "log" + + "github.com/insomniacslk/dhcp/iana" ) // WithClientID adds a client ID option to a DHCPv6 packet @@ -53,9 +55,9 @@ func WithUserClass(uc []byte) Modifier { } // WithArchType adds an arch type option to the packet -func WithArchType(at ArchType) Modifier { +func WithArchType(at iana.ArchType) Modifier { return func(d DHCPv6) DHCPv6 { - ao := OptClientArchType{ArchType: at} + ao := OptClientArchType{ArchTypes: []iana.ArchType{at}} d.AddOption(&ao) return d } diff --git a/dhcpv6/option_archtype.go b/dhcpv6/option_archtype.go index a1b4a9b..231eddd 100644 --- a/dhcpv6/option_archtype.go +++ b/dhcpv6/option_archtype.go @@ -6,42 +6,14 @@ package dhcpv6 import ( "encoding/binary" "fmt" -) - -//ArchType encodes an architecture type in an uint16 -type ArchType uint16 + "strings" -// see rfc4578 -const ( - INTEL_X86PC ArchType = 0 - NEC_PC98 ArchType = 1 - EFI_ITANIUM ArchType = 2 - DEC_ALPHA ArchType = 3 - ARC_X86 ArchType = 4 - INTEL_LEAN_CLIENT ArchType = 5 - EFI_IA32 ArchType = 6 - EFI_BC ArchType = 7 - EFI_XSCALE ArchType = 8 - EFI_X86_64 ArchType = 9 + "github.com/insomniacslk/dhcp/iana" ) -// ArchTypeToStringMap maps an ArchType to a mnemonic name -var ArchTypeToStringMap = map[ArchType]string{ - INTEL_X86PC: "Intel x86PC", - NEC_PC98: "NEC/PC98", - EFI_ITANIUM: "EFI Itanium", - DEC_ALPHA: "DEC Alpha", - ARC_X86: "Arc x86", - INTEL_LEAN_CLIENT: "Intel Lean Client", - EFI_IA32: "EFI IA32", - EFI_BC: "EFI BC", - EFI_XSCALE: "EFI Xscale", - EFI_X86_64: "EFI x86-64", -} - // OptClientArchType represents an option CLIENT_ARCH_TYPE type OptClientArchType struct { - ArchType ArchType + ArchTypes []iana.ArchType } func (op *OptClientArchType) Code() OptionCode { @@ -49,23 +21,28 @@ func (op *OptClientArchType) Code() OptionCode { } func (op *OptClientArchType) ToBytes() []byte { - buf := make([]byte, 6) + buf := make([]byte, 4) binary.BigEndian.PutUint16(buf[0:2], uint16(OptionClientArchType)) binary.BigEndian.PutUint16(buf[2:4], uint16(op.Length())) - binary.BigEndian.PutUint16(buf[4:6], uint16(op.ArchType)) + u16 := make([]byte, 2) + for _, at := range op.ArchTypes { + binary.BigEndian.PutUint16(u16, uint16(at)) + buf = append(buf, u16...) + } return buf } func (op *OptClientArchType) Length() int { - return 2 + return 2*len(op.ArchTypes) } func (op *OptClientArchType) String() string { - name, ok := ArchTypeToStringMap[op.ArchType] - if !ok { - name = "Unknown" + atStrings := make([]string, 0) + for _, at := range op.ArchTypes { + name := iana.ArchTypeToString(at) + atStrings = append(atStrings, name) } - return fmt.Sprintf("OptClientArchType{archtype=%v}", name) + return fmt.Sprintf("OptClientArchType{archtype=%v}", strings.Join(atStrings, ", ")) } // ParseOptClientArchType builds an OptClientArchType structure from @@ -73,9 +50,12 @@ func (op *OptClientArchType) String() string { // length bytes. func ParseOptClientArchType(data []byte) (*OptClientArchType, error) { opt := OptClientArchType{} - if len(data) != 2 { - return nil, fmt.Errorf("Invalid arch type data length. Expected 2 bytes, got %v", len(data)) + if len(data) == 0 || len(data)%2 != 0 { + return nil, fmt.Errorf("Invalid arch type data length. Expected multiple of 2 larger than 2, got %v", len(data)) + } + for idx := 0; idx < len(data); idx += 2 { + b := data[idx : idx+2] + opt.ArchTypes = append(opt.ArchTypes, iana.ArchType(binary.BigEndian.Uint16(b))) } - opt.ArchType = ArchType(binary.BigEndian.Uint16(data)) return &opt, nil } diff --git a/dhcpv6/option_archtype_test.go b/dhcpv6/option_archtype_test.go index 748c8c5..1848e55 100644 --- a/dhcpv6/option_archtype_test.go +++ b/dhcpv6/option_archtype_test.go @@ -3,6 +3,7 @@ package dhcpv6 import ( "testing" + "github.com/insomniacslk/dhcp/iana" "github.com/stretchr/testify/require" ) @@ -12,7 +13,7 @@ func TestParseOptClientArchType(t *testing.T) { } opt, err := ParseOptClientArchType(data) require.NoError(t, err) - require.Equal(t, opt.ArchType, EFI_IA32) + require.Equal(t, opt.ArchTypes[0], iana.EFI_IA32) } func TestParseOptClientArchTypeInvalid(t *testing.T) { @@ -37,7 +38,7 @@ func TestOptClientArchTypeParseAndToBytes(t *testing.T) { func TestOptClientArchType(t *testing.T) { opt := OptClientArchType{ - ArchType: EFI_ITANIUM, + ArchTypes: []iana.ArchType{iana.EFI_ITANIUM}, } require.Equal(t, opt.Length(), 2) require.Equal(t, opt.Code(), OptionClientArchType) diff --git a/dhcpv6/utils.go b/dhcpv6/utils.go deleted file mode 100644 index 1681661..0000000 --- a/dhcpv6/utils.go +++ /dev/null @@ -1,77 +0,0 @@ -package dhcpv6 - -import ( - "errors" - "strings" -) - -// IsNetboot function takes a DHCPv6 message and returns true if the machine -// is trying to netboot. It checks if "boot file" is one of the requested -// options, which is useful for SOLICIT/REQUEST packet types, it also checks -// if the "boot file" option is included in the packet, which is useful for -// ADVERTISE/REPLY packet. -func IsNetboot(msg DHCPv6) bool { - for _, optoro := range msg.GetOption(OptionORO) { - for _, o := range optoro.(*OptRequestedOption).RequestedOptions() { - if o == OptionBootfileURL { - return true - } - } - } - if optbf := msg.GetOneOption(OptionBootfileURL); optbf != nil { - return true - } - return false -} - -// IsUsingUEFI function takes a DHCPv6 message and returns true if -// the machine trying to netboot is using UEFI of false if it is not. -func IsUsingUEFI(msg DHCPv6) bool { - // RFC 4578 says: - // As of the writing of this document, the following pre-boot - // architecture types have been requested. - // Type Architecture Name - // ---- ----------------- - // 0 Intel x86PC - // 1 NEC/PC98 - // 2 EFI Itanium - // 3 DEC Alpha - // 4 Arc x86 - // 5 Intel Lean Client - // 6 EFI IA32 - // 7 EFI BC - // 8 EFI Xscale - // 9 EFI x86-64 - if opt := msg.GetOneOption(OptionClientArchType); opt != nil { - optat := opt.(*OptClientArchType) - // TODO investigate if other types are appropriate - if optat.ArchType == EFI_BC || optat.ArchType == EFI_X86_64 { - return true - } - } - if opt := msg.GetOneOption(OptionUserClass); opt != nil { - optuc := opt.(*OptUserClass) - for _, uc := range optuc.UserClasses { - if strings.Contains(string(uc), "EFI") { - return true - } - } - } - return false -} - -// GetTransactionID returns a transactionID of a message or its inner message -// in case of relay -func GetTransactionID(packet DHCPv6) (uint32, error) { - if message, ok := packet.(*DHCPv6Message); ok { - return message.TransactionID(), nil - } - if relay, ok := packet.(*DHCPv6Relay); ok { - message, err := relay.GetInnerMessage() - if err != nil { - return 0, err - } - return GetTransactionID(message) - } - return 0, errors.New("Invalid DHCPv6 packet") -} diff --git a/dhcpv6/utils_test.go b/dhcpv6/utils_test.go deleted file mode 100644 index f3b53f0..0000000 --- a/dhcpv6/utils_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package dhcpv6 - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestIsNetboot(t *testing.T) { - msg1 := DHCPv6Message{} - require.False(t, IsNetboot(&msg1)) - - msg2 := DHCPv6Message{} - optro := OptRequestedOption{} - optro.AddRequestedOption(OptionBootfileURL) - msg2.AddOption(&optro) - require.True(t, IsNetboot(&msg2)) - - msg3 := DHCPv6Message{} - optbf := OptBootFileURL{} - msg3.AddOption(&optbf) - require.True(t, IsNetboot(&msg3)) -} - -func TestIsUsingUEFIArchTypeTrue(t *testing.T) { - msg := DHCPv6Message{} - opt := OptClientArchType{ArchType: EFI_BC} - msg.AddOption(&opt) - require.True(t, IsUsingUEFI(&msg)) -} - -func TestIsUsingUEFIArchTypeFalse(t *testing.T) { - msg := DHCPv6Message{} - opt := OptClientArchType{ArchType: INTEL_X86PC} - msg.AddOption(&opt) - require.False(t, IsUsingUEFI(&msg)) -} - -func TestIsUsingUEFIUserClassTrue(t *testing.T) { - msg := DHCPv6Message{} - opt := OptUserClass{UserClasses: [][]byte{[]byte("ipxeUEFI")}} - msg.AddOption(&opt) - require.True(t, IsUsingUEFI(&msg)) -} - -func TestIsUsingUEFIUserClassFalse(t *testing.T) { - msg := DHCPv6Message{} - opt := OptUserClass{UserClasses: [][]byte{[]byte("ipxeLegacy")}} - msg.AddOption(&opt) - require.False(t, IsUsingUEFI(&msg)) -} - -func TestGetTransactionIDMessage(t *testing.T) { - message, err := NewMessage() - require.NoError(t, err) - transactionID, err := GetTransactionID(message) - require.NoError(t, err) - require.Equal(t, transactionID, message.(*DHCPv6Message).TransactionID()) -} - -func TestGetTransactionIDRelay(t *testing.T) { - message, err := NewMessage() - require.NoError(t, err) - relay, err := EncapsulateRelay(message, MessageTypeRelayForward, nil, nil) - require.NoError(t, err) - transactionID, err := GetTransactionID(relay) - require.NoError(t, err) - require.Equal(t, transactionID, message.(*DHCPv6Message).TransactionID()) -} |