diff options
-rw-r--r-- | dhcpv4/client.go | 97 | ||||
-rw-r--r-- | netboot/netconf.go | 4 |
2 files changed, 78 insertions, 23 deletions
diff --git a/dhcpv4/client.go b/dhcpv4/client.go index 61722f4..3370ecf 100644 --- a/dhcpv4/client.go +++ b/dhcpv4/client.go @@ -85,7 +85,7 @@ func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byt } // makeRawSocket creates a socket that can be passed to unix.Sendto. -func makeRawSocket() (int, error) { +func makeRawSocket(ifname string) (int, error) { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW) if err != nil { return fd, err @@ -98,13 +98,17 @@ func makeRawSocket() (int, error) { if err != nil { return fd, err } + err = BindToInterface(fd, ifname) + if err != nil { + return fd, err + } return fd, nil } // MakeBroadcastSocket creates a socket that can be passed to unix.Sendto // that will send packets out to the broadcast address. func MakeBroadcastSocket(ifname string) (int, error) { - fd, err := makeRawSocket() + fd, err := makeRawSocket(ifname) if err != nil { return fd, err } @@ -121,25 +125,27 @@ func MakeListeningSocket(ifname string) (int, error) { return makeListeningSocketWithCustomPort(ifname, ClientPort) } +func htons(v uint16) uint16 { + var tmp [2]byte + binary.BigEndian.PutUint16(tmp[:], v) + return binary.LittleEndian.Uint16(tmp[:]) +} + func makeListeningSocketWithCustomPort(ifname string, port int) (int, error) { - fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM, int(htons(unix.ETH_P_IP))) if err != nil { return fd, err } - err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) + iface, err := net.InterfaceByName(ifname) if err != nil { return fd, err } - var addr [4]byte - copy(addr[:], net.IPv4zero.To4()) - if err = unix.Bind(fd, &unix.SockaddrInet4{Port: port, Addr: addr}); err != nil { - return fd, err + llAddr := unix.SockaddrLinklayer{ + Ifindex: iface.Index, + Protocol: htons(unix.ETH_P_IP), } - err = BindToInterface(fd, ifname) - if err != nil { - return fd, err - } - return fd, nil + err = unix.Bind(fd, &llAddr) + return fd, err } func toUDPAddr(addr net.Addr, defaultAddr *net.UDPAddr) (*net.UDPAddr, error) { @@ -202,7 +208,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier if raddr.IP.Equal(net.IPv4bcast) { sfd, err = MakeBroadcastSocket(ifname) } else { - sfd, err = makeRawSocket() + sfd, err = makeRawSocket(ifname) } if err != nil { return conversation, err @@ -212,6 +218,18 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier return conversation, err } + defer func() { + // close the sockets + if err := unix.Close(sfd); err != nil { + log.Printf("unix.Close(sendFd) failed: %v", err) + } + if sfd != rfd { + if err := unix.Close(rfd); err != nil { + log.Printf("unix.Close(recvFd) failed: %v", err) + } + } + }() + // Discover if discover == nil { discover, err = NewDiscoveryForInterface(ifname) @@ -244,6 +262,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier return conversation, err } conversation = append(conversation, ack) + return conversation, nil } @@ -267,30 +286,59 @@ func (c *Client) sendReceive(sendFd, recvFd int, packet *DHCPv4, messageType Mes // Create a goroutine to perform the blocking send, and time it out after // a certain amount of time. var ( - destination [4]byte + destination [net.IPv4len]byte response *DHCPv4 ) copy(destination[:], raddr.IP.To4()) remoteAddr := unix.SockaddrInet4{Port: laddr.Port, Addr: destination} recvErrors := make(chan error, 1) go func(errs chan<- error) { - conn, innerErr := net.FileConn(os.NewFile(uintptr(recvFd), "")) - if innerErr != nil { + // set read timeout + timeout := unix.NsecToTimeval(c.ReadTimeout.Nanoseconds()) + if innerErr := unix.SetsockoptTimeval(recvFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout); innerErr != nil { errs <- innerErr return } - defer conn.Close() - conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) - for { buf := make([]byte, MaxUDPReceivedPacketSize) - n, _, _, _, innerErr := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{}) + n, _, innerErr := unix.Recvfrom(recvFd, buf, 0) if innerErr != nil { errs <- innerErr return } - response, innerErr = FromBytes(buf[:n]) + var iph ipv4.Header + if err := iph.Parse(buf[:n]); err != nil { + // skip non-IP data + continue + } + if iph.Protocol != 17 { + // skip non-UDP packets + continue + } + udph := buf[iph.Len:n] + // check source and destination ports + srcPort := int(binary.BigEndian.Uint16(udph[0:2])) + expectedSrcPort := ServerPort + if c.RemoteAddr != nil { + expectedSrcPort = c.RemoteAddr.(*net.UDPAddr).Port + } + if srcPort != expectedSrcPort { + continue + } + dstPort := int(binary.BigEndian.Uint16(udph[2:4])) + expectedDstPort := ClientPort + if c.RemoteAddr != nil { + expectedDstPort = c.LocalAddr.(*net.UDPAddr).Port + } + if dstPort != expectedDstPort { + continue + } + // UDP checksum is not checked + pLen := int(binary.BigEndian.Uint16(udph[4:6])) + payload := buf[iph.Len+8 : iph.Len+8+pLen] + + response, innerErr = FromBytes(payload) if innerErr != nil { errs <- innerErr return @@ -315,12 +363,17 @@ func (c *Client) sendReceive(sendFd, recvFd int, packet *DHCPv4, messageType Mes } recvErrors <- nil }(recvErrors) + + // send the request while the goroutine waits for replies if err = unix.Sendto(sendFd, packetBytes, 0, &remoteAddr); err != nil { return nil, err } select { case err = <-recvErrors: + if err == unix.EAGAIN { + return nil, errors.New("timed out while listening for replies") + } if err != nil { return nil, err } diff --git a/netboot/netconf.go b/netboot/netconf.go index bf0274e..3cc5232 100644 --- a/netboot/netconf.go +++ b/netboot/netconf.go @@ -208,7 +208,9 @@ func ConfigureInterface(ifname string, netconf *NetConf) error { for _, ns := range netconf.DNSServers { resolvconf += fmt.Sprintf("nameserver %s\n", ns) } - resolvconf += fmt.Sprintf("search %s\n", strings.Join(netconf.DNSSearchList, " ")) + if len(netconf.DNSSearchList) > 0 { + resolvconf += fmt.Sprintf("search %s\n", strings.Join(netconf.DNSSearchList, " ")) + } if err = ioutil.WriteFile("/etc/resolv.conf", []byte(resolvconf), 0644); err != nil { return fmt.Errorf("could not write resolv.conf file %v", err) } |