package main import ( "context" "fmt" "io" "net" "net/netip" "net/url" "os" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" gen "golang.zx2c4.com/wireguard/android/gen" "golang.zx2c4.com/wireguard/device" ) const ( IPPROTO_TCP = 6 ) type UidRequest struct { Data AddrPortPair RetCh chan string } type LibwgServiceImpl struct { gen.UnimplementedLibwgServer logger *device.Logger httpProxy *HttpProxy uidRequest chan UidRequest stopReverse chan bool } var service *LibwgServiceImpl var server *grpc.Server func NewLibwgService(logger *device.Logger) gen.LibwgServer { return &LibwgServiceImpl{ logger: logger, uidRequest: make(chan UidRequest), stopReverse: make(chan bool), } } func StartGrpc(sock_path string, logger *device.Logger) (int, string) { if server != nil { return -1, "Already started" } if _, err := os.Stat(sock_path); err == nil { if err := os.RemoveAll(sock_path); err != nil { return -1, fmt.Sprintf("Cleanup failed: %v %v", sock_path, err) } } listener, err := net.Listen("unix", sock_path) if err != nil { return -1, fmt.Sprintf("Listen failed: %v %v", sock_path, err) } server = grpc.NewServer() service = NewLibwgService(logger).(*LibwgServiceImpl) gen.RegisterLibwgServer(server, service) go func() { server.Serve(listener) }() logger.Verbosef("gRPC started") return 0, "" } func (e *LibwgServiceImpl) Version(ctx context.Context, req *gen.VersionRequest) (*gen.VersionResponse, error) { r := &gen.VersionResponse{ Version: Version(), } return r, nil } func (e *LibwgServiceImpl) StopGrpc(ctx context.Context, req *gen.StopGrpcRequest) (*gen.StopGrpcResponse, error) { if server != nil { server.Stop() server = nil service = nil } r := &gen.StopGrpcResponse{ } return r, nil } func buildStartHttpProxyError(message string) (*gen.StartHttpProxyResponse, error) { r := &gen.StartHttpProxyResponse{ Error: &gen.Error{ Message: message, }, } return r, nil } func (e *LibwgServiceImpl) StartHttpProxy(ctx context.Context, req *gen.StartHttpProxyRequest) (*gen.StartHttpProxyResponse, error) { var listenPort uint16 if e.httpProxy == nil { e.httpProxy = NewHttpProxy(e.uidRequest, e.logger) var err error listenPort, err = e.httpProxy.Start() if err != nil { e.httpProxy = nil return buildStartHttpProxyError(fmt.Sprintf("Http proxy start failed: %v", err)) } } else { listenPort = e.httpProxy.GetAddrPort().Port() } pacFileUrl, err := url.Parse(req.PacFileUrl) if err != nil { return buildStartHttpProxyError(fmt.Sprintf("Bad pacFileUrl: %v (%s)", err, req.PacFileUrl)) } err = e.httpProxy.SetPacFileUrl(pacFileUrl) if err != nil { return buildStartHttpProxyError(fmt.Sprintf("Bad pacFileUrl: %v (%s)", req.PacFileUrl)) } r := &gen.StartHttpProxyResponse{ ListenPort: uint32(listenPort), } return r, nil } func (e *LibwgServiceImpl) StopHttpProxy(ctx context.Context, req *gen.StopHttpProxyRequest) (*gen.StopHttpProxyResponse, error) { if e.httpProxy == nil { r := &gen.StopHttpProxyResponse{ Error: &gen.Error{ Message: fmt.Sprintf("Http proxy not running"), }, } return r, nil } e.httpProxy.Stop() e.httpProxy = nil e.stopReverse <- true r := &gen.StopHttpProxyResponse{} return r, nil } func (e *LibwgServiceImpl) Reverse(stream gen.Libwg_ReverseServer) error { e.logger.Verbosef("Reverse enter loop") for e.httpProxy != nil { var err error // err := contextError(stream.Context()) err = stream.Context().Err() if err != nil { e.logger.Verbosef("Reverse: context: %v", err) return err } select { case <-e.stopReverse: e.logger.Verbosef("Reverse: stop") break case uidReq := <-e.uidRequest: addrPortPair := uidReq.Data local := addrPortPair.local remote := addrPortPair.remote r := &gen.ReverseResponse{ Request: &gen.ReverseResponse_Uid{ Uid: &gen.GetConnectionOwnerUidRequest{ Protocol: IPPROTO_TCP, Local: &gen.InetSocketAddress{ Address: &gen.InetAddress{ Address: local.Addr().AsSlice(), }, Port: uint32(local.Port()), }, Remote: &gen.InetSocketAddress{ Address: &gen.InetAddress{ Address: remote.Addr().AsSlice(), }, Port: uint32(remote.Port()), }, }, }, } stream.Send(r) req, err := stream.Recv() if err == io.EOF { e.logger.Verbosef("no more data") uidReq.RetCh <- "" break } if err != nil { err = status.Errorf(codes.Unknown, "cannot receive stream request: %v", err) e.logger.Verbosef("Reverse: %v", err) uidReq.RetCh <- "" return err } e.logger.Verbosef("Reverse: received, wait: %v", req) uidReq.RetCh <- req.GetUid().GetPackage() } } e.logger.Verbosef("Reverse returns") return nil } func (e *LibwgServiceImpl) IpcSet(ctx context.Context, req *gen.IpcSetRequest) (*gen.IpcSetResponse, error) { tunnel, ok := GetTunnel(req.GetTunnel().GetHandle()) if !ok { r := &gen.IpcSetResponse{ Error: &gen.Error{ Message: fmt.Sprintf("Invalid tunnel"), }, } return r, nil } err := tunnel.device.IpcSet(req.GetConfig()) if err != nil { r := &gen.IpcSetResponse{ Error: &gen.Error{ Message: fmt.Sprintf("IpcSet failed: %v", err), }, } return r, nil } r := &gen.IpcSetResponse{ } return r, nil } func (e *LibwgServiceImpl) Dhcp(ctx context.Context, req *gen.DhcpRequest) (*gen.DhcpResponse, error) { var relayAddr netip.Addr var sourceAddr netip.Addr source := req.GetSource() if source != nil { sourceAddr, _ = netip.AddrFromSlice(source.GetAddress()) } if !sourceAddr.IsValid() || !sourceAddr.Is6() { r := &gen.DhcpResponse{ Error: &gen.Error{ Message: fmt.Sprintf("DHCPv6 source address missing"), }, } return r, nil } relay := req.GetRelay() if relay != nil { relayAddr, _ = netip.AddrFromSlice(relay.GetAddress()) } else { // Construct relay address from source prefix relayRaw := source.GetAddress()[:8] relayRaw = append(relayRaw, 0) relayRaw = append(relayRaw, 0) relayRaw = append(relayRaw, 0) relayRaw = append(relayRaw, 0) relayRaw = append(relayRaw, 0) relayRaw = append(relayRaw, 0) relayRaw = append(relayRaw, 0) relayRaw = append(relayRaw, 1) relayAddr, _ = netip.AddrFromSlice(relayRaw) } if !relayAddr.IsValid() || !relayAddr.Is6() { r := &gen.DhcpResponse{ Error: &gen.Error{ Message: fmt.Sprintf("DHCPv6 relay address calculation failed"), }, } return r, nil } e.logger.Verbosef("RunDhcp %v %v", sourceAddr, relayAddr) leases, err := RunDhcp(ctx, sourceAddr, relayAddr) if err != nil { r := &gen.DhcpResponse{ Error: &gen.Error{ Message: fmt.Sprintf("RunDhcp failed: %v", err), }, } return r, nil } r := &gen.DhcpResponse{ Leases: leases, } return r, nil }