diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2017-10-08 22:03:32 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2017-10-08 22:03:32 +0200 |
commit | a72b0f7ae5dda27d839bb317b7c01d11b215e77a (patch) | |
tree | bd70e6fc71574b300e6e928b7887d69c7cf4ddef /src/conn_linux.go | |
parent | 2d856045a0dbfc15d38d738e2a9d159ba2a49a47 (diff) |
Added new UDPBind interface
Diffstat (limited to 'src/conn_linux.go')
-rw-r--r-- | src/conn_linux.go | 271 |
1 files changed, 176 insertions, 95 deletions
diff --git a/src/conn_linux.go b/src/conn_linux.go index 034fb8b..8942b03 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -14,35 +14,158 @@ import ( "unsafe" ) -import "fmt" - /* Supports source address caching * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 - * So this code is platform dependent. - * - * It is important that the endpoint is only updated after the packet content has been authenticated! + * So this code is remains platform dependent. */ type Endpoint struct { - // source (selected based on dst type) - // (could use RawSockaddrAny and unsafe) - // TODO: Merge - src6 unix.RawSockaddrInet6 - src4 unix.RawSockaddrInet4 - src4if int32 - - dst unix.RawSockaddrAny + src unix.RawSockaddrInet6 + dst unix.RawSockaddrInet6 +} + +type IPv4Source struct { + src unix.RawSockaddrInet4 + Ifindex int32 } -type Socket int +type Bind struct { + sock4 int + sock6 int +} -/* Returns a byte representation of the source field(s) - * for use in "under load" cookie computations. - */ -func (endpoint *Endpoint) Source() []byte { - return nil +func CreateUDPBind(port uint16) (UDPBind, uint16, error) { + var err error + var bind Bind + + bind.sock6, port, err = create6(port) + if err != nil { + return nil, port, err + } + + bind.sock4, port, err = create4(port) + if err != nil { + unix.Close(bind.sock6) + } + return &bind, port, err +} + +func (bind *Bind) SetMark(value uint32) error { + err := unix.SetsockoptInt( + bind.sock6, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + + return unix.SetsockoptInt( + bind.sock4, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) +} + +func (bind *Bind) Close() error { + err1 := unix.Close(bind.sock6) + err2 := unix.Close(bind.sock4) + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { + return receive6( + bind.sock6, + buff, + end, + ) +} + +func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { + return receive4( + bind.sock4, + buff, + end, + ) +} + +func (bind *Bind) Send(buff []byte, end *Endpoint) error { + switch end.src.Family { + case unix.AF_INET6: + return send6(bind.sock6, end, buff) + case unix.AF_INET: + return send4(bind.sock4, end, buff) + default: + return errors.New("Unknown address family of source") + } +} + +func sockaddrToString(addr unix.RawSockaddrInet6) string { + var udpAddr net.UDPAddr + + switch addr.Family { + case unix.AF_INET6: + udpAddr.Port = int(addr.Port) + udpAddr.IP = addr.Addr[:] + return udpAddr.String() + + case unix.AF_INET: + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) + udpAddr.Port = int(ptr.Port) + udpAddr.IP = net.IPv4( + ptr.Addr[0], + ptr.Addr[1], + ptr.Addr[2], + ptr.Addr[3], + ) + return udpAddr.String() + + default: + return "<unknown address family>" + } +} + +func (end *Endpoint) DestinationIP() net.IP { + switch end.dst.Family { + case unix.AF_INET6: + return end.dst.Addr[:] + case unix.AF_INET: + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + return net.IPv4( + ptr.Addr[0], + ptr.Addr[1], + ptr.Addr[2], + ptr.Addr[3], + ) + default: + return nil + } +} + +func (end *Endpoint) SourceToBytes() []byte { + ptr := unsafe.Pointer(&end.src) + arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) + return arr[:] +} + +func (end *Endpoint) SourceToString() string { + return sockaddrToString(end.src) +} + +func (end *Endpoint) DestinationToString() string { + return sockaddrToString(end.dst) +} + +func (end *Endpoint) ClearSrc() { + end.src = unix.RawSockaddrInet6{} } func zoneToUint32(zone string) (uint32, error) { @@ -56,7 +179,7 @@ func zoneToUint32(zone string) (uint32, error) { return uint32(n), err } -func CreateIPv4Socket(port uint16) (Socket, uint16, error) { +func create4(port uint16) (int, uint16, error) { // create socket @@ -100,18 +223,10 @@ func CreateIPv4Socket(port uint16) (Socket, uint16, error) { unix.Close(fd) } - return Socket(fd), uint16(addr.Port), err + return fd, uint16(addr.Port), err } -func CloseIPv4Socket(sock Socket) error { - return unix.Close(int(sock)) -} - -func CloseIPv6Socket(sock Socket) error { - return unix.Close(int(sock)) -} - -func CreateIPv6Socket(port uint16) (Socket, uint16, error) { +func create6(port uint16) (int, uint16, error) { // create socket @@ -166,13 +281,7 @@ func CreateIPv6Socket(port uint16) (Socket, uint16, error) { unix.Close(fd) } - return Socket(fd), uint16(addr.Port), err -} - -func (end *Endpoint) ClearSrc() { - end.src4if = 0 - end.src4 = unix.RawSockaddrInet4{} - end.src6 = unix.RawSockaddrInet6{} + return fd, uint16(addr.Port), err } func (end *Endpoint) Set(s string) error { @@ -187,23 +296,23 @@ func (end *Endpoint) Set(s string) error { if err != nil { return err } - ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst)) - ptr.Family = unix.AF_INET6 - ptr.Port = uint16(addr.Port) - ptr.Flowinfo = 0 - ptr.Scope_id = zone - copy(ptr.Addr[:], ipv6[:]) + dst := &end.dst + dst.Family = unix.AF_INET6 + dst.Port = uint16(addr.Port) + dst.Flowinfo = 0 + dst.Scope_id = zone + copy(dst.Addr[:], ipv6[:]) end.ClearSrc() return nil } ipv4 := addr.IP.To4() if ipv4 != nil { - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - ptr.Family = unix.AF_INET - ptr.Port = uint16(addr.Port) - ptr.Zero = [8]byte{} - copy(ptr.Addr[:], ipv4) + dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + dst.Family = unix.AF_INET + dst.Port = uint16(addr.Port) + dst.Zero = [8]byte{} + copy(dst.Addr[:], ipv4) end.ClearSrc() return nil } @@ -211,7 +320,7 @@ func (end *Endpoint) Set(s string) error { return errors.New("Failed to recognize IP address format") } -func send6(sock uintptr, end *Endpoint, buff []byte) error { +func send6(sock int, end *Endpoint, buff []byte) error { // construct message header @@ -229,8 +338,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { Len: unix.SizeofInet6Pktinfo, }, unix.Inet6Pktinfo{ - Addr: end.src6.Addr, - Ifindex: end.src6.Scope_id, + Addr: end.src.Addr, + Ifindex: end.src.Scope_id, }, } @@ -248,7 +357,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { _, _, errno := unix.Syscall( unix.SYS_SENDMSG, - sock, + uintptr(sock), uintptr(unsafe.Pointer(&msghdr)), 0, ) @@ -258,7 +367,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { return errno } -func send4(sock uintptr, end *Endpoint, buff []byte) error { +func send4(sock int, end *Endpoint, buff []byte) error { // construct message header @@ -266,6 +375,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.SetLen(len(buff)) + src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) + cmsg := struct { cmsghdr unix.Cmsghdr pktinfo unix.Inet4Pktinfo @@ -276,8 +387,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { Len: unix.SizeofInet4Pktinfo, }, unix.Inet4Pktinfo{ - Spec_dst: end.src4.Addr, - Ifindex: end.src4if, + Spec_dst: src4.src.Addr, + Ifindex: src4.Ifindex, }, } @@ -295,7 +406,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { _, _, errno := unix.Syscall( unix.SYS_SENDMSG, - sock, + uintptr(sock), uintptr(unsafe.Pointer(&msghdr)), 0, ) @@ -305,28 +416,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { return errno } -func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error { - - // extract underlying file descriptor - - file, err := c.File() - if err != nil { - return err - } - sock := file.Fd() - - // send depending on address family of dst - - family := *((*uint16)(unsafe.Pointer(&end.dst))) - if family == unix.AF_INET { - return send4(sock, end, buff) - } else if family == unix.AF_INET6 { - return send6(sock, end, buff) - } - return errors.New("Unknown address family of source") -} - -func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) { +func receive4(sock int, buff []byte, end *Endpoint) (int, error) { // contruct message header @@ -360,22 +450,21 @@ func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) { return 0, errno } - fmt.Println(msghdr) - fmt.Println(cmsg) - // update source cache if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4.Addr = cmsg.pktinfo.Spec_dst - end.src4if = cmsg.pktinfo.Ifindex + src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) + src4.src.Family = unix.AF_INET + src4.src.Addr = cmsg.pktinfo.Spec_dst + src4.Ifindex = cmsg.pktinfo.Ifindex } return int(size), nil } -func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { +func receive6(sock int, buff []byte, end *Endpoint) (int, error) { // contruct message header @@ -414,18 +503,10 @@ func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6.Addr = cmsg.pktinfo.Addr - end.src6.Scope_id = cmsg.pktinfo.Ifindex + end.src.Family = unix.AF_INET6 + end.src.Addr = cmsg.pktinfo.Addr + end.src.Scope_id = cmsg.pktinfo.Ifindex } return int(size), nil } - -func SetMark(sock Socket, value uint32) error { - return unix.SetsockoptInt( - int(sock), - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) -} |