// Copyright (C) 2016 Nippon Telegraph and Telephone Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build openbsd

package server

import (
	"encoding/binary"
	"fmt"
	"net"
	"os"
	"syscall"
	"unsafe"

	log "github.com/sirupsen/logrus"
)

// Numeric identifiers used when building and parsing the key-management
// messages that this file exchanges over its AF_KEY socket.
const (
	// PF_KEY_V2 is used both as the protocol argument when the AF_KEY socket
	// is created and as the version field of every request header.
	PF_KEY_V2 = 2

	// SADB_X_SATYPE_TCPSIGNATURE is the SA type placed in every request header.
	SADB_X_SATYPE_TCPSIGNATURE = 8

	// Extension type identifiers, stored in the second 16-bit field of each
	// extension that follows the message header.
	SADB_EXT_SA          = 1
	SADB_EXT_ADDRESS_SRC = 5
	SADB_EXT_ADDRESS_DST = 6
	SADB_EXT_KEY_AUTH    = 8
	SADB_EXT_SPIRANGE    = 16

	// Message types that rfkeyRequest knows how to build.
	SADB_GETSPI = 1
	SADB_UPDATE = 2
	SADB_DELETE = 4

	// SADB_X_EALG_AES is the value written into the encrypt field of the SA
	// extension for UPDATE and DELETE requests.
	SADB_X_EALG_AES = 12

	// SADB_SASTATE_MATURE is the value written into the state field of the SA
	// extension for UPDATE and DELETE requests.
	SADB_SASTATE_MATURE = 1
)

// sadbMsg is the fixed 16-byte header that starts every message on the AF_KEY
// socket. Requests are serialized by copying this struct's memory verbatim
// (see sendSadbMsg), so the field order and widths define the wire layout.
type sadbMsg struct {
	sadbMsgVersion  uint8  // PF_KEY_V2 on requests
	sadbMsgType     uint8  // one of the SADB_* message types
	sadbMsgErrno    uint8  // non-zero in a reply signals failure
	sadbMsgSatype   uint8  // SADB_X_SATYPE_TCPSIGNATURE on requests
	sadbMsgLen      uint16 // total message length in 8-byte units, header included
	sadbMsgReserved uint16 // never set or decoded by this file
	sadbMsgSeq      uint32 // request sequence number, matched against replies
	sadbMsgPid      uint32 // process id of the requester, matched against replies
}

// DecodeFromBytes fills s from the first SADB_MSG_SIZE bytes of data.
//
// Multi-byte fields are read as little-endian, which matches the raw memory
// copy used when requests are sent only on little-endian hosts. The reserved
// field is left untouched.
//
// It returns an error, and leaves s unmodified, when data is shorter than a
// complete header.
func (s *sadbMsg) DecodeFromBytes(data []byte) error {
	if len(data) < SADB_MSG_SIZE {
		// The error must actually be returned; otherwise the indexing below
		// panics on a short buffer.
		return fmt.Errorf("too short for sadbMsg %d", len(data))
	}
	s.sadbMsgVersion = data[0]
	s.sadbMsgType = data[1]
	s.sadbMsgErrno = data[2]
	s.sadbMsgSatype = data[3]
	s.sadbMsgLen = binary.LittleEndian.Uint16(data[4:6])
	s.sadbMsgSeq = binary.LittleEndian.Uint32(data[8:12])
	s.sadbMsgPid = binary.LittleEndian.Uint32(data[12:16])
	return nil
}

// sadbSpirange is the SPI-range extension appended to SADB_GETSPI requests.
// Its memory is copied verbatim into the outgoing message.
type sadbSpirange struct {
	sadbSpirangeLen      uint16 // extension length in 8-byte units
	sadbSpirangeExttype  uint16 // SADB_EXT_SPIRANGE
	sadbSpirangeMin      uint32 // lowest acceptable SPI
	sadbSpirangeMax      uint32 // highest acceptable SPI
	sadbSpirangeReserved uint32 // left zero
}

// sadbAddress is the header of an address extension. In the messages built by
// rfkeyRequest it is immediately followed by a sockaddrIn, and its length
// field covers both.
type sadbAddress struct {
	sadbAddressLen      uint16 // header plus socket address, in 8-byte units
	sadbAddressExttype  uint16 // SADB_EXT_ADDRESS_SRC or SADB_EXT_ADDRESS_DST
	sadbAddressReserved uint32 // left zero
}

// sadbExt holds the two 16-bit fields (length, type) that begin every
// extension. Nothing in the visible part of this file references it;
// pfkeyReply reads the same two fields directly from the byte slice.
type sadbExt struct {
	sadbExtLen  uint16
	sadbExtType uint16
}

// sadbSa is the security-association extension sent with SADB_UPDATE and
// SADB_DELETE requests. pfkeyReply also reads the SPI out of this extension in
// replies. Its memory is copied verbatim into the outgoing message.
type sadbSa struct {
	sadbSaLen     uint16 // extension length in 8-byte units
	sadbSaExttype uint16 // SADB_EXT_SA
	sadbSaSpi     uint32 // SPI previously obtained through SADB_GETSPI
	sadbSaReplay  uint8  // left zero by this file
	sadbSaState   uint8  // SADB_SASTATE_MATURE on requests
	sadbSaAuth    uint8  // left zero by this file
	sadbSaEncrypt uint8  // SADB_X_EALG_AES on requests
	sadbSaFlags   uint32 // left zero by this file
}

// sadbKey is the header of a key extension. In SADB_UPDATE requests it is
// followed by the key bytes, zero-padded to a multiple of 8.
type sadbKey struct {
	sadbKeyLen      uint16 // header plus padded key, in 8-byte units
	sadbKeyExttype  uint16 // SADB_EXT_KEY_AUTH
	sadbKeyBits     uint16 // unpadded key length in bits
	sadbKeyReserved uint16 // left zero
}

// In-memory sizes, in bytes, of the message header and extension structs.
// They double as wire sizes because the structs are serialized by copying
// their memory.
const (
	SADB_MSG_SIZE      = int(unsafe.Sizeof(sadbMsg{}))
	SADB_SPIRANGE_SIZE = int(unsafe.Sizeof(sadbSpirange{}))
	SADB_ADDRESS_SIZE  = int(unsafe.Sizeof(sadbAddress{}))
	SADB_SA_SIZE       = int(unsafe.Sizeof(sadbSa{}))
	SADB_KEY_SIZE      = int(unsafe.Sizeof(sadbKey{}))
)

// sockaddrIn is the 16-byte IPv4 socket address that follows each sadbAddress
// header. Its memory is copied verbatim into the outgoing message.
type sockaddrIn struct {
	ssLen    uint8  // always 16 when built by newSockaddrIn
	ssFamily uint8  // syscall.AF_INET, or zero for the empty address
	ssPort   uint16 // never set by this file
	// ssAddr is packed by newSockaddrIn so that the first octet of the
	// address ends up in the lowest-order byte.
	ssAddr uint32
	pad    [8]byte
}

// newSockaddrIn builds the socket address that accompanies an address
// extension.
//
// An empty addr produces an address with only the length set (family zero).
// Otherwise addr is parsed as an IPv4 address and packed with its first octet
// in the lowest-order byte of ssAddr.
//
// The function panics if addr is non-empty but is not a valid IPv4 address,
// because the parse result is then nil.
func newSockaddrIn(addr string) sockaddrIn {
	sa := sockaddrIn{ssLen: 16}
	if addr == "" {
		return sa
	}

	ip4 := net.ParseIP(addr).To4()
	sa.ssFamily = syscall.AF_INET
	// Little-endian decoding places ip4[0] in the low byte and ip4[3] in the
	// high byte.
	sa.ssAddr = binary.LittleEndian.Uint32(ip4)
	return sa
}

// roundUp returns v rounded up to the next multiple of 8; values that are
// already a multiple of 8 are returned unchanged.
func roundUp(v int) int {
	rem := v % 8
	if rem == 0 {
		return v
	}
	return v + 8 - rem
}

// b returns a freshly allocated copy of the length bytes of memory starting
// at p. The caller must make sure that at least length bytes are readable
// there.
func b(p unsafe.Pointer, length int) []byte {
	out := make([]byte, length)
	for i := range out {
		// The pointer arithmetic stays inside a single expression.
		out[i] = *(*byte)(unsafe.Pointer(uintptr(p) + uintptr(i)))
	}
	return out
}

// seq is the sequence number of the most recent request: sendSadbMsg
// increments it before each send and pfkeyReply compares replies against it.
var seq uint32

// fd is the AF_KEY socket shared by every request. sendSadbMsg opens it
// lazily and treats the zero value as "not opened yet".
var fd int

// spiInMap and spiOutMap remember, per peer address, the SPI obtained by saAdd
// so that saDelete can later refer to it. spiInMap holds the SPI of the
// association whose source is the peer; spiOutMap holds the one whose
// destination is the peer.
// NOTE(review): nothing in this file guards these variables against
// concurrent use — confirm that callers serialize access.
var spiInMap map[string]uint32 = map[string]uint32{}
var spiOutMap map[string]uint32 = map[string]uint32{}

// pfkeyReply reads the next message from the AF_KEY socket, checks that it
// answers the most recent request, and returns the SPI carried by its SA
// extension.
//
// The header is peeked first to learn the message length; the complete message
// is then read so that it is removed from the socket queue before any
// validation takes place. A message that is rejected (error reply, sequence
// or pid mismatch) therefore cannot block the replies that follow it.
//
// It returns a non-nil error when receiving fails, when the message is
// truncated or malformed, when the reply carries a non-zero errno, or when the
// sequence number or pid do not match this process's latest request. When the
// reply is valid but carries no SA extension, it returns (0, nil).
func pfkeyReply() (spi uint32, err error) {
	hdr := make([]byte, SADB_MSG_SIZE)
	count, _, _, _, err := syscall.Recvmsg(fd, hdr, nil, syscall.MSG_PEEK)
	if err != nil {
		return 0, fmt.Errorf("receiving sadbmsg header: %v", err)
	}
	if count != len(hdr) {
		return 0, fmt.Errorf("incomplete sadbmsg %d %d", len(hdr), count)
	}

	h := sadbMsg{}
	if err := h.DecodeFromBytes(hdr); err != nil {
		return 0, err
	}

	// The length field counts 8-byte units. Widen before multiplying so that
	// a large value cannot wrap around uint16.
	total := 8 * int(h.sadbMsgLen)
	size := total
	if size < SADB_MSG_SIZE {
		// Still read at least the header so the bogus message is consumed.
		size = SADB_MSG_SIZE
	}
	msg := make([]byte, size)
	count, _, _, _, err = syscall.Recvmsg(fd, msg, nil, 0)
	if err != nil {
		return 0, fmt.Errorf("receiving sadbmsg body: %v", err)
	}

	if h.sadbMsgErrno != 0 {
		return 0, fmt.Errorf("sadb msg reply error %d", h.sadbMsgErrno)
	}
	if h.sadbMsgSeq != seq {
		return 0, fmt.Errorf("sadb msg sequence doesn't match %d %d", h.sadbMsgSeq, seq)
	}
	if pid := os.Getpid(); h.sadbMsgPid != uint32(pid) {
		return 0, fmt.Errorf("sadb msg pid doesn't match %d %d", h.sadbMsgPid, pid)
	}
	if total < SADB_MSG_SIZE {
		return 0, fmt.Errorf("sadbmsg length %d shorter than header %d", total, SADB_MSG_SIZE)
	}
	if count != total {
		return 0, fmt.Errorf("incomplete sadbmsg body %d %d", total, count)
	}

	// Walk the extensions that follow the header, looking for the SA.
	ext := msg[SADB_MSG_SIZE:total]
	for len(ext) >= 4 {
		extLen := 8 * int(binary.LittleEndian.Uint16(ext[0:2]))
		extType := binary.LittleEndian.Uint16(ext[2:4])
		if extType == SADB_EXT_SA {
			if len(ext) < 8 {
				return 0, fmt.Errorf("truncated sadb sa extension %d", len(ext))
			}
			return binary.LittleEndian.Uint32(ext[4:8]), nil
		}

		// A zero length would never advance, and a length reaching past the
		// remaining bytes leaves nothing more to inspect.
		if extLen == 0 || extLen >= len(ext) {
			break
		}
		ext = ext[extLen:]
	}
	return 0, nil
}

// sendSadbMsg stamps msg with the next sequence number and the total length,
// serializes it in front of body, and writes the result to the AF_KEY socket.
// The socket is opened on first use.
//
// msg is modified in place (sequence and length fields). body must already be
// a multiple of 8 bytes long, since the length field is expressed in 8-byte
// units.
//
// It returns the error from opening the socket or from the write, or a
// "short write" error when fewer bytes than expected were written.
func sendSadbMsg(msg *sadbMsg, body []byte) (err error) {
	if fd == 0 {
		// Assign the package-level fd only on success: a failed Socket call
		// returns -1, which would otherwise defeat the "not opened yet" check
		// on every later call.
		s, serr := syscall.Socket(syscall.AF_KEY, syscall.SOCK_RAW, PF_KEY_V2)
		if serr != nil {
			return serr
		}
		fd = s
	}

	seq++
	msg.sadbMsgSeq = seq
	msg.sadbMsgLen = uint16((len(body) + SADB_MSG_SIZE) / 8)

	buf := append(b(unsafe.Pointer(msg), SADB_MSG_SIZE), body...)

	n, err := syscall.Write(fd, buf)
	if err != nil {
		// Report the real errno rather than a misleading "short write -1".
		return err
	}
	if n != len(buf) {
		return fmt.Errorf("short write %d %d", n, len(buf))
	}
	return nil
}

// rfkeyRequest builds a TCP-signature request of the given type and sends it
// with sendSadbMsg.
//
// The message body is laid out as:
//
//	SADB_UPDATE, SADB_DELETE: SA extension carrying spi
//	SADB_GETSPI:              SPI-range extension (0x100 .. 0xffffffff)
//	always:                   destination address, then source address
//	SADB_UPDATE only:         authentication key, zero-padded to 8 bytes
//
// src and dst are IPv4 addresses in textual form; an empty string yields an
// address with no family set. The reply is not read here; callers use
// pfkeyReply for that.
func rfkeyRequest(msgType uint8, src, dst string, spi uint32, key string) error {
	header := sadbMsg{
		sadbMsgVersion: PF_KEY_V2,
		sadbMsgType:    msgType,
		sadbMsgSatype:  SADB_X_SATYPE_TCPSIGNATURE,
		sadbMsgPid:     uint32(os.Getpid()),
	}

	srcSock := newSockaddrIn(src)
	srcSockLen := roundUp(int(srcSock.ssLen))
	srcExt := sadbAddress{
		sadbAddressLen:     uint16(SADB_ADDRESS_SIZE+srcSockLen) / 8,
		sadbAddressExttype: SADB_EXT_ADDRESS_SRC,
	}

	dstSock := newSockaddrIn(dst)
	dstSockLen := roundUp(int(dstSock.ssLen))
	dstExt := sadbAddress{
		sadbAddressLen:     uint16(SADB_ADDRESS_SIZE+dstSockLen) / 8,
		sadbAddressExttype: SADB_EXT_ADDRESS_DST,
	}

	var body []byte
	// put appends the raw memory of one struct to the message body.
	put := func(p unsafe.Pointer, n int) {
		body = append(body, b(p, n)...)
	}

	// Leading extension: depends on the message type.
	if msgType == SADB_UPDATE || msgType == SADB_DELETE {
		sa := sadbSa{
			sadbSaLen:     uint16(SADB_SA_SIZE / 8),
			sadbSaExttype: SADB_EXT_SA,
			sadbSaSpi:     spi,
			sadbSaState:   SADB_SASTATE_MATURE,
			sadbSaEncrypt: SADB_X_EALG_AES,
		}
		put(unsafe.Pointer(&sa), SADB_SA_SIZE)
	} else if msgType == SADB_GETSPI {
		spiRange := sadbSpirange{
			sadbSpirangeLen:     uint16(SADB_SPIRANGE_SIZE) / 8,
			sadbSpirangeExttype: SADB_EXT_SPIRANGE,
			sadbSpirangeMin:     0x100,
			sadbSpirangeMax:     0xffffffff,
		}
		put(unsafe.Pointer(&spiRange), SADB_SPIRANGE_SIZE)
	}

	// Addresses: destination first, then source.
	put(unsafe.Pointer(&dstExt), SADB_ADDRESS_SIZE)
	put(unsafe.Pointer(&dstSock), dstSockLen)
	put(unsafe.Pointer(&srcExt), SADB_ADDRESS_SIZE)
	put(unsafe.Pointer(&srcSock), srcSockLen)

	// Trailing extension: the key, for UPDATE only.
	if msgType == SADB_UPDATE {
		paddedLen := roundUp(len(key))
		keyExt := sadbKey{
			sadbKeyLen:     uint16((SADB_KEY_SIZE + paddedLen) / 8),
			sadbKeyExttype: SADB_EXT_KEY_AUTH,
			sadbKeyBits:    uint16(len(key) * 8),
		}
		// The zero-initialised tail supplies the padding.
		padded := make([]byte, paddedLen)
		copy(padded, key)

		put(unsafe.Pointer(&keyExt), SADB_KEY_SIZE)
		body = append(body, padded...)
	}

	return sendSadbMsg(&header, body)
}

// saAdd installs the TCP MD5 key for the peer at address in both directions:
// first the association whose source is the peer, then the one whose
// destination is the peer.
//
// For each direction it obtains an SPI with SADB_GETSPI, records it in
// spiInMap or spiOutMap, and then sends SADB_UPDATE with the key. It stops at
// the first failure and returns that error; an SPI that was already recorded
// stays in its map.
func saAdd(address, key string) error {
	install := func(src, dst string) error {
		if err := rfkeyRequest(SADB_GETSPI, src, dst, 0, ""); err != nil {
			return err
		}
		spi, err := pfkeyReply()
		if err != nil {
			return err
		}

		// An empty source marks the outgoing direction.
		spis := spiInMap
		if src == "" {
			spis = spiOutMap
		}
		spis[address] = spi

		if err = rfkeyRequest(SADB_UPDATE, src, dst, spi, key); err != nil {
			return err
		}
		_, err = pfkeyReply()
		return err
	}

	directions := [2][2]string{
		{address, ""},
		{"", address},
	}
	for _, d := range directions {
		if err := install(d[0], d[1]); err != nil {
			return err
		}
	}
	return nil
}

func saDelete(address string) error {
	if spi, y := spiInMap[address]; y {
		if err := rfkeyRequest(SADB_DELETE, address, "", spi, ""); err != nil {
			log.WithFields(log.Fields{
				"Topic": "Peer",
				"Key":   address,
			}).Info("failed to delete md5 for incoming")
		} else {
			log.WithFields(log.Fields{
				"Topic": "Peer",
				"Key":   address,
			}).Info("can't find spi for md5 for incoming")
		}
	}
	if spi, y := spiOutMap[address]; y {
		if err := rfkeyRequest(SADB_DELETE, "", address, spi, ""); err != nil {
			log.WithFields(log.Fields{
				"Topic": "Peer",
				"Key":   address,
			}).Info("failed to delete md5 for outgoing")
		} else {
			log.WithFields(log.Fields{
				"Topic": "Peer",
				"Key":   address,
			}).Info("can't find spi for md5 for outgoing")
		}
	}
	return nil
}

// Socket option numbers that this file defines by hand instead of taking them
// from package syscall.
const (
	TCP_MD5SIG       = 0x4 // TCP MD5 Signature (RFC2385)
	// NOTE(review): confirm that 73 is the value OpenBSD's headers assign to
	// IPV6_MINHOPCOUNT; this file is built for openbsd only.
	IPV6_MINHOPCOUNT = 73  // Generalized TTL Security Mechanism (RFC5082)
)

// setsockoptTcpMD5Sig switches on the TCP_MD5SIG option for socket fd and then
// updates the security associations for address: a non-empty key installs
// them via saAdd, an empty key removes them via saDelete.
//
// A failing setsockopt is returned wrapped in an os.SyscallError and no
// association is touched.
func setsockoptTcpMD5Sig(fd int, address string, key string) error {
	err := syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, TCP_MD5SIG, 1)
	switch {
	case err != nil:
		return os.NewSyscallError("setsockopt", err)
	case len(key) == 0:
		return saDelete(address)
	default:
		return saAdd(address, key)
	}
}

// SetTcpMD5SigSockopt enables TCP MD5 signatures on listener l and installs
// (non-empty key) or removes (empty key) the key for the peer at address.
//
// It returns the error from obtaining the listener's file, or the error from
// setsockoptTcpMD5Sig.
func SetTcpMD5SigSockopt(l *net.TCPListener, address string, key string) error {
	fi, _, err := extractFileAndFamilyFromTCPListener(l)
	if err != nil {
		return err
	}
	// Register Close only once the file is known to be valid.
	defer fi.Close()
	return setsockoptTcpMD5Sig(int(fi.Fd()), address, key)
}

// setsockoptIpTtl sets the outgoing TTL on socket fd to value. For
// syscall.AF_INET6 it sets IPV6_UNICAST_HOPS at the IPv6 level; for any other
// family it sets IP_TTL at the IP level.
//
// A failure is returned wrapped in an os.SyscallError; success yields nil.
func setsockoptIpTtl(fd int, family int, value int) error {
	var level, name int
	switch family {
	case syscall.AF_INET6:
		level, name = syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS
	default:
		level, name = syscall.IPPROTO_IP, syscall.IP_TTL
	}

	err := syscall.SetsockoptInt(fd, level, name, value)
	// NewSyscallError passes a nil error through as nil.
	return os.NewSyscallError("setsockopt", err)
}

// SetListenTcpTTLSockopt sets the outgoing TTL (or IPv6 hop limit) on the
// listening socket l to ttl.
//
// It returns the error from obtaining the listener's file, or the error from
// setsockoptIpTtl.
func SetListenTcpTTLSockopt(l *net.TCPListener, ttl int) error {
	fi, family, err := extractFileAndFamilyFromTCPListener(l)
	if err != nil {
		return err
	}
	// Register Close only once the file is known to be valid.
	defer fi.Close()
	return setsockoptIpTtl(int(fi.Fd()), family, ttl)
}

// SetTcpTTLSockopt sets the outgoing TTL (or IPv6 hop limit) on the
// established connection conn to ttl.
//
// It returns the error from obtaining the connection's file, or the error
// from setsockoptIpTtl.
func SetTcpTTLSockopt(conn *net.TCPConn, ttl int) error {
	fi, family, err := extractFileAndFamilyFromTCPConn(conn)
	if err != nil {
		return err
	}
	// Register Close only once the file is known to be valid.
	defer fi.Close()
	return setsockoptIpTtl(int(fi.Fd()), family, ttl)
}

// setsockoptIpMinTtl sets the minimum TTL accepted on socket fd to value. For
// syscall.AF_INET6 it sets IPV6_MINHOPCOUNT at the IPv6 level; for any other
// family it sets IP_MINTTL at the IP level.
//
// A failure is returned wrapped in an os.SyscallError; success yields nil.
func setsockoptIpMinTtl(fd int, family int, value int) error {
	var level, name int
	switch family {
	case syscall.AF_INET6:
		level, name = syscall.IPPROTO_IPV6, IPV6_MINHOPCOUNT
	default:
		level, name = syscall.IPPROTO_IP, syscall.IP_MINTTL
	}

	err := syscall.SetsockoptInt(fd, level, name, value)
	// NewSyscallError passes a nil error through as nil.
	return os.NewSyscallError("setsockopt", err)
}

// SetTcpMinTTLSockopt sets the minimum TTL (or IPv6 hop count) accepted on
// the established connection conn to ttl.
//
// It returns the error from obtaining the connection's file, or the error
// from setsockoptIpMinTtl.
func SetTcpMinTTLSockopt(conn *net.TCPConn, ttl int) error {
	fi, family, err := extractFileAndFamilyFromTCPConn(conn)
	if err != nil {
		return err
	}
	// Register Close only once the file is known to be valid.
	defer fi.Close()
	return setsockoptIpMinTtl(int(fi.Fd()), family, ttl)
}

// TCPDialer bundles a net.Dialer with the per-connection TCP options that
// callers may request. In this file DialTCP does not apply AuthPassword, Ttl
// or TtlMin; it only logs a warning when one of them is set.
type TCPDialer struct {
	// Dialer supplies the local address and timeout used by DialTCP.
	net.Dialer

	// MD5 authentication password.
	AuthPassword string

	// The TTL value to set outgoing connection.
	Ttl uint8

	// The minimum TTL value for incoming packets.
	TtlMin uint8
}

// DialTCP opens a TCP connection to addr:port.
//
// Of the embedded Dialer only LocalAddr and Timeout are honoured. When
// LocalAddr is nil the local address is left for the system to choose.
// AuthPassword, Ttl and TtlMin are not applied here: each one that is set
// only produces a warning in the log.
//
// It returns an error when the remote or local address cannot be resolved,
// when the dial fails, or when the resulting connection is not a
// *net.TCPConn.
func (d *TCPDialer) DialTCP(addr string, port int) (*net.TCPConn, error) {
	unsupported := []struct {
		requested bool
		message   string
	}{
		{d.AuthPassword != "", "setting md5 for active connection is not supported"},
		{d.Ttl != 0, "setting ttl for active connection is not supported"},
		{d.TtlMin != 0, "setting min ttl for active connection is not supported"},
	}
	for _, u := range unsupported {
		if u.requested {
			log.WithFields(log.Fields{
				"Topic": "Peer",
				"Key":   addr,
			}).Warn(u.message)
		}
	}

	raddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(addr, fmt.Sprintf("%d", port)))
	if err != nil {
		return nil, fmt.Errorf("invalid remote address: %s", err)
	}

	dialer := net.Dialer{Timeout: d.Timeout}
	// LocalAddr is optional on net.Dialer; calling String on a nil interface
	// would panic, so resolve it only when it is present.
	if d.LocalAddr != nil {
		laddr, err := net.ResolveTCPAddr("tcp", d.LocalAddr.String())
		if err != nil {
			return nil, fmt.Errorf("invalid local address: %s", err)
		}
		dialer.LocalAddr = laddr
	}

	conn, err := dialer.Dial("tcp", raddr.String())
	if err != nil {
		return nil, err
	}
	tcpConn, ok := conn.(*net.TCPConn)
	if !ok {
		conn.Close()
		return nil, fmt.Errorf("unexpected connection type %T", conn)
	}
	return tcpConn, nil
}