diff options
Diffstat (limited to 'conn')
-rw-r--r-- | conn/bind_linux.go | 50 | ||||
-rw-r--r-- | conn/bind_std.go | 19 | ||||
-rw-r--r-- | conn/bind_windows.go | 19 | ||||
-rw-r--r-- | conn/bindtest/bindtest.go | 14 | ||||
-rw-r--r-- | conn/conn.go | 37 |
5 files changed, 50 insertions, 89 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go index 7b970e6..da0670a 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -14,6 +14,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "golang.zx2c4.com/go118/netip" ) type ipv4Source struct { @@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil) func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { var end LinuxSocketEndpoint - addr, err := parseEndpoint(s) + e, err := netip.ParseAddrPort(s) if err != nil { return nil, err } - ipv4 := addr.IP.To4() - if ipv4 != nil { + if e.Addr().Is4() { dst := end.dst4() end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) + dst.Port = int(e.Port()) + dst.Addr = e.Addr().As4() end.ClearSrc() return &end, nil } - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) + if e.Addr().Is6() { + zone, err := zoneToUint32(e.Addr().Zone()) if err != nil { return nil, err } dst := end.dst6() end.isV6 = true - dst.Port = addr.Port + dst.Port = int(e.Port()) dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) + dst.Addr = e.Addr().As16() end.ClearSrc() return &end, nil } @@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { } } -func (end *LinuxSocketEndpoint) SrcIP() net.IP { +func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { if !end.isV6 { - return net.IPv4( - end.src4().Src[0], - end.src4().Src[1], - end.src4().Src[2], - end.src4().Src[3], - ) + return netip.AddrFrom4(end.src4().Src) } else { - return end.src6().src[:] + return netip.AddrFrom16(end.src6().src) } } -func (end *LinuxSocketEndpoint) DstIP() net.IP { +func (end *LinuxSocketEndpoint) DstIP() netip.Addr { if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) + return netip.AddrFrom4(end.dst4().Addr) } else { - return end.dst6().Addr[:] + return netip.AddrFrom16(end.dst6().Addr) } } @@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string { } func (end *LinuxSocketEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() + var port int if !end.isV6 { - udpAddr.Port = end.dst4().Port + port = end.dst4().Port } else { - udpAddr.Port = end.dst6().Port + port = end.dst6().Port } - return udpAddr.String() + return netip.AddrPortFrom(end.DstIP(), uint16(port)).String() } func (end *LinuxSocketEndpoint) ClearDst() { diff --git a/conn/bind_std.go b/conn/bind_std.go index cb85cfd..a3cbb15 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -10,6 +10,8 @@ import ( "net" "sync" "syscall" + + "golang.zx2c4.com/go118/netip" ) // StdNetBind is meant to be a temporary solution on platforms for which @@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil) var _ Endpoint = (*StdNetEndpoint)(nil) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*StdNetEndpoint)(addr), err + e, err := netip.ParseAddrPort(s) + return (*StdNetEndpoint)(&net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + }), err } func (*StdNetEndpoint) ClearSrc() {} -func (e *StdNetEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP +func (e *StdNetEndpoint) DstIP() netip.Addr { + a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP) + return a } -func (e *StdNetEndpoint) SrcIP() net.IP { - return nil // not supported +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported } func (e *StdNetEndpoint) DstToBytes() []byte { diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 42e06ad..26a3af8 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -15,6 +15,7 @@ import ( "unsafe" "golang.org/x/sys/windows" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn/winrio" ) @@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { func (*WinRingEndpoint) ClearSrc() {} -func (e *WinRingEndpoint) DstIP() net.IP { +func (e *WinRingEndpoint) DstIP() netip.Addr { switch e.family { case windows.AF_INET: - return append([]byte{}, e.data[2:6]...) + return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) case windows.AF_INET6: - return append([]byte{}, e.data[6:22]...) + return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) } - return nil + return netip.Addr{} } -func (e *WinRingEndpoint) SrcIP() net.IP { - return nil // not supported +func (e *WinRingEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported } func (e *WinRingEndpoint) DstToBytes() []byte { @@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToString() string { switch e.family { case windows.AF_INET: - addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} - return addr.String() + netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() case windows.AF_INET6: var zone string if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { zone = strconv.FormatUint(uint64(scope), 10) } - addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} - return addr.String() + return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() } return "" } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 7d43fb3..6a45896 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -10,8 +10,8 @@ import ( "math/rand" "net" "os" - "strconv" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" ) @@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } -func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } +func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } -func (c ChannelEndpoint) SrcIP() net.IP { return nil } +func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { c.closeSignal = make(chan bool) @@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { } func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { - _, port, err := net.SplitHostPort(s) + addr, err := netip.ParseAddrPort(s) if err != nil { return nil, err } - i, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return nil, err - } - return ChannelEndpoint(i), nil + return ChannelEndpoint(addr.Port()), nil } diff --git a/conn/conn.go b/conn/conn.go index 9cce9ad..35fb6b1 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -9,10 +9,11 @@ package conn import ( "errors" "fmt" - "net" "reflect" "runtime" "strings" + + "golang.zx2c4.com/go118/netip" ) // A ReceiveFunc receives a single inbound packet from the network. @@ -68,8 +69,8 @@ type Endpoint interface { SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP + DstIP() netip.Addr + SrcIP() netip.Addr } var ( @@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string { } return name } - -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} |