summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv4/client.go97
-rw-r--r--netboot/netconf.go4
2 files changed, 78 insertions, 23 deletions
diff --git a/dhcpv4/client.go b/dhcpv4/client.go
index 61722f4..3370ecf 100644
--- a/dhcpv4/client.go
+++ b/dhcpv4/client.go
@@ -85,7 +85,7 @@ func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byt
}
// makeRawSocket creates a socket that can be passed to unix.Sendto.
-func makeRawSocket() (int, error) {
+func makeRawSocket(ifname string) (int, error) {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW)
if err != nil {
return fd, err
@@ -98,13 +98,17 @@ func makeRawSocket() (int, error) {
if err != nil {
return fd, err
}
+ err = BindToInterface(fd, ifname)
+ if err != nil {
+ return fd, err
+ }
return fd, nil
}
// MakeBroadcastSocket creates a socket that can be passed to unix.Sendto
// that will send packets out to the broadcast address.
func MakeBroadcastSocket(ifname string) (int, error) {
- fd, err := makeRawSocket()
+ fd, err := makeRawSocket(ifname)
if err != nil {
return fd, err
}
@@ -121,25 +125,27 @@ func MakeListeningSocket(ifname string) (int, error) {
return makeListeningSocketWithCustomPort(ifname, ClientPort)
}
+func htons(v uint16) uint16 {
+ var tmp [2]byte
+ binary.BigEndian.PutUint16(tmp[:], v)
+ return binary.LittleEndian.Uint16(tmp[:])
+}
+
func makeListeningSocketWithCustomPort(ifname string, port int) (int, error) {
- fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
+ fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM, int(htons(unix.ETH_P_IP)))
if err != nil {
return fd, err
}
- err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
+ iface, err := net.InterfaceByName(ifname)
if err != nil {
return fd, err
}
- var addr [4]byte
- copy(addr[:], net.IPv4zero.To4())
- if err = unix.Bind(fd, &unix.SockaddrInet4{Port: port, Addr: addr}); err != nil {
- return fd, err
+ llAddr := unix.SockaddrLinklayer{
+ Ifindex: iface.Index,
+ Protocol: htons(unix.ETH_P_IP),
}
- err = BindToInterface(fd, ifname)
- if err != nil {
- return fd, err
- }
- return fd, nil
+ err = unix.Bind(fd, &llAddr)
+ return fd, err
}
func toUDPAddr(addr net.Addr, defaultAddr *net.UDPAddr) (*net.UDPAddr, error) {
@@ -202,7 +208,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier
if raddr.IP.Equal(net.IPv4bcast) {
sfd, err = MakeBroadcastSocket(ifname)
} else {
- sfd, err = makeRawSocket()
+ sfd, err = makeRawSocket(ifname)
}
if err != nil {
return conversation, err
@@ -212,6 +218,18 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier
return conversation, err
}
+ defer func() {
+ // close the sockets
+ if err := unix.Close(sfd); err != nil {
+ log.Printf("unix.Close(sendFd) failed: %v", err)
+ }
+ if sfd != rfd {
+ if err := unix.Close(rfd); err != nil {
+ log.Printf("unix.Close(recvFd) failed: %v", err)
+ }
+ }
+ }()
+
// Discover
if discover == nil {
discover, err = NewDiscoveryForInterface(ifname)
@@ -244,6 +262,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier
return conversation, err
}
conversation = append(conversation, ack)
+
return conversation, nil
}
@@ -267,30 +286,59 @@ func (c *Client) sendReceive(sendFd, recvFd int, packet *DHCPv4, messageType Mes
// Create a goroutine to perform the blocking send, and time it out after
// a certain amount of time.
var (
- destination [4]byte
+ destination [net.IPv4len]byte
response *DHCPv4
)
copy(destination[:], raddr.IP.To4())
remoteAddr := unix.SockaddrInet4{Port: laddr.Port, Addr: destination}
recvErrors := make(chan error, 1)
go func(errs chan<- error) {
- conn, innerErr := net.FileConn(os.NewFile(uintptr(recvFd), ""))
- if innerErr != nil {
+ // set read timeout
+ timeout := unix.NsecToTimeval(c.ReadTimeout.Nanoseconds())
+ if innerErr := unix.SetsockoptTimeval(recvFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout); innerErr != nil {
errs <- innerErr
return
}
- defer conn.Close()
- conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
-
for {
buf := make([]byte, MaxUDPReceivedPacketSize)
- n, _, _, _, innerErr := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{})
+ n, _, innerErr := unix.Recvfrom(recvFd, buf, 0)
if innerErr != nil {
errs <- innerErr
return
}
- response, innerErr = FromBytes(buf[:n])
+ var iph ipv4.Header
+ if err := iph.Parse(buf[:n]); err != nil {
+ // skip non-IP data
+ continue
+ }
+ if iph.Protocol != 17 {
+ // skip non-UDP packets
+ continue
+ }
+ udph := buf[iph.Len:n]
+ // check source and destination ports
+ srcPort := int(binary.BigEndian.Uint16(udph[0:2]))
+ expectedSrcPort := ServerPort
+ if c.RemoteAddr != nil {
+ expectedSrcPort = c.RemoteAddr.(*net.UDPAddr).Port
+ }
+ if srcPort != expectedSrcPort {
+ continue
+ }
+ dstPort := int(binary.BigEndian.Uint16(udph[2:4]))
+ expectedDstPort := ClientPort
+ if c.RemoteAddr != nil {
+ expectedDstPort = c.LocalAddr.(*net.UDPAddr).Port
+ }
+ if dstPort != expectedDstPort {
+ continue
+ }
+ // UDP checksum is not checked
+ pLen := int(binary.BigEndian.Uint16(udph[4:6]))
+ payload := buf[iph.Len+8 : iph.Len+8+pLen]
+
+ response, innerErr = FromBytes(payload)
if innerErr != nil {
errs <- innerErr
return
@@ -315,12 +363,17 @@ func (c *Client) sendReceive(sendFd, recvFd int, packet *DHCPv4, messageType Mes
}
recvErrors <- nil
}(recvErrors)
+
+ // send the request while the goroutine waits for replies
if err = unix.Sendto(sendFd, packetBytes, 0, &remoteAddr); err != nil {
return nil, err
}
select {
case err = <-recvErrors:
+ if err == unix.EAGAIN {
+ return nil, errors.New("timed out while listening for replies")
+ }
if err != nil {
return nil, err
}
diff --git a/netboot/netconf.go b/netboot/netconf.go
index bf0274e..3cc5232 100644
--- a/netboot/netconf.go
+++ b/netboot/netconf.go
@@ -208,7 +208,9 @@ func ConfigureInterface(ifname string, netconf *NetConf) error {
for _, ns := range netconf.DNSServers {
resolvconf += fmt.Sprintf("nameserver %s\n", ns)
}
- resolvconf += fmt.Sprintf("search %s\n", strings.Join(netconf.DNSSearchList, " "))
+ if len(netconf.DNSSearchList) > 0 {
+ resolvconf += fmt.Sprintf("search %s\n", strings.Join(netconf.DNSSearchList, " "))
+ }
if err = ioutil.WriteFile("/etc/resolv.conf", []byte(resolvconf), 0644); err != nil {
return fmt.Errorf("could not write resolv.conf file %v", err)
}