// Copyright 2017 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tcp

import (
	"fmt"
	"sync"

	"gvisor.googlesource.com/gvisor/pkg/tcpip"
	"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
)

// drainSegmentLocked requests a drain of pending segments and blocks until
// e.drainDone is signalled (closed or sent on). Only the first call does any
// work; once e.drainDone has been created, later calls return immediately.
//
// The caller must hold e.mu. The lock is released while waiting and is
// re-acquired before returning, so endpoint state may have changed by then.
func (e *endpoint) drainSegmentLocked() {
	if e.drainDone == nil {
		// NOTE(review): e.undrain is created here but not used in this
		// function; presumably another goroutine waits on it to resume
		// processing after the save completes — confirm.
		e.drainDone = make(chan struct{})
		e.undrain = make(chan struct{})

		// Drop the lock while waiting, presumably so the goroutine woken by
		// notificationWaker can take e.mu and perform the drain.
		e.mu.Unlock()
		e.notificationWaker.Assert()
		<-e.drainDone
		e.mu.Lock()
	}
}

// beforeSave is invoked by stateify.
//
// It first sets the segment queue limit to zero to stop incoming packets, then,
// depending on the endpoint's state, drains queued segments. It panics with
// tcpip.ErrSaveRejection for a connected endpoint (including one that became
// connected while draining), panics if a closed/error endpoint still has a
// running worker, and panics on an unknown state.
func (e *endpoint) beforeSave() {
	// Stop incoming packets.
	e.segmentQueue.setLimit(0)

	e.mu.Lock()
	defer e.mu.Unlock()

	switch e.state {
	case stateInitial, stateBound:
		// No draining is performed in these states.
	case stateListen:
		if !e.segmentQueue.empty() {
			e.drainSegmentLocked()
		}
	case stateConnecting:
		// drainSegmentLocked releases e.mu while waiting, and processing the
		// queued segments may move the endpoint to stateConnected, so the
		// state is re-checked afterwards.
		e.drainSegmentLocked()
		if e.state != stateConnected {
			break
		}
		fallthrough
	case stateConnected:
		// FIXME: saving connected endpoints is not supported yet.
		panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
	case stateClosed, stateError:
		if e.workerRunning {
			panic("endpoint still has worker running in closed or error state")
		}
	default:
		panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
	}
}

// afterLoad is invoked by stateify.
//
// It re-attaches the endpoint to stack.StackFromEnv, sets the segment queue
// limit to twice the receive buffer size, and re-initializes workMu. For
// endpoints saved in an initial, bound, listening or connecting state it
// validates the saved buffer sizes against the stack's limits. It then replays
// Bind, Listen and Connect so that the endpoint is set up again on the restored
// stack. Any validation or replay failure panics.
func (e *endpoint) afterLoad() {
	e.stack = stack.StackFromEnv
	e.segmentQueue.setLimit(2 * e.rcvBufSize)
	e.workMu.Init()

	state := e.state
	switch state {
	case stateInitial, stateBound, stateListen, stateConnecting:
		// Each buffer size is checked against the limits for its own
		// direction. Checking rcvBufSize against the send-buffer limits would
		// reject valid endpoints whenever the send and receive limits differ.
		var ss SendBufferSizeOption
		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
			if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
				panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
			}
		}
		var rs ReceiveBufferSizeOption
		if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
			if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
				panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
			}
		}
	}

	switch state {
	case stateBound, stateListen, stateConnecting:
		// Reset to the initial state before re-binding, presumably because
		// Bind only accepts endpoints in stateInitial.
		e.state = stateInitial
		if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
			panic("endpoint binding failed: " + err.String())
		}
	}

	switch state {
	case stateListen:
		backlog := cap(e.acceptedChan)
		// NOTE(review): acceptedChan was rebuilt by loadAcceptedChan and may
		// hold pending accepted endpoints; clearing it here may discard them
		// unless Listen preserves them. Confirm against Listen.
		e.acceptedChan = nil
		if err := e.Listen(backlog); err != nil {
			panic("endpoint listening failed: " + err.String())
		}
	case stateConnecting:
		// A restored connect is expected to start asynchronously; any other
		// result, including nil, is treated as a failure.
		if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
			msg := "<nil>"
			if err != nil {
				msg = err.String()
			}
			panic("endpoint connecting failed: " + msg)
		}
	}
}

// saveAcceptedChan is invoked by stateify.
//
// It converts e.acceptedChan into a serializable endpointChan holding the
// queued endpoints (in receive order) and the channel's capacity. A nil channel
// is saved as the zero endpointChan. As a side effect the channel is closed and
// e.acceptedChan is set to nil. It panics if fewer endpoints are read back than
// were queued, which means something else received from the channel meanwhile.
func (e *endpoint) saveAcceptedChan() endpointChan {
	ch := e.acceptedChan
	if ch == nil {
		return endpointChan{}
	}

	close(ch)
	queued := len(ch)
	saved := make([]*endpoint, 0, queued)
	for ep := range ch {
		saved = append(saved, ep)
	}
	if len(saved) != queued {
		panic("endpoint.acceptedChan buffer got consumed by background context")
	}

	capacity := cap(ch)
	e.acceptedChan = nil
	return endpointChan{buffer: saved, cap: capacity}
}

// loadAcceptedChan is invoked by stateify.
//
// It rebuilds e.acceptedChan with the saved capacity and re-queues the saved
// endpoints in order. A zero capacity leaves e.acceptedChan untouched (nil).
func (e *endpoint) loadAcceptedChan(c endpointChan) {
	if c.cap != 0 {
		ch := make(chan *endpoint, c.cap)
		e.acceptedChan = ch
		for i := range c.buffer {
			ch <- c.buffer[i]
		}
	}
}

// endpointChan is the serializable form of endpoint.acceptedChan. It is
// produced by saveAcceptedChan and consumed by loadAcceptedChan.
//
// buffer holds the endpoints that were queued in the channel, in the order they
// were received. cap is the channel's capacity; a zero cap is restored as a nil
// channel, so an unbuffered channel would not survive a save/restore cycle.
type endpointChan struct {
	buffer []*endpoint
	cap    int
}

// saveLastError is invoked by stateify. It serializes e.lastError as its
// message string, using "" to represent a nil error.
func (e *endpoint) saveLastError() string {
	if err := e.lastError; err != nil {
		return err.String()
	}
	return ""
}

// loadLastError is invoked by stateify. It maps a message saved by
// saveLastError back to its *tcpip.Error; "" leaves e.lastError unchanged.
func (e *endpoint) loadLastError(s string) {
	if s != "" {
		e.lastError = loadError(s)
	}
}

// saveHardError is invoked by stateify. It serializes e.hardError as its
// message string, using "" to represent a nil error.
func (e *endpoint) saveHardError() string {
	if err := e.hardError; err != nil {
		return err.String()
	}
	return ""
}

// loadHardError is invoked by stateify. It maps a message saved by
// saveHardError back to its *tcpip.Error; "" leaves e.hardError unchanged.
func (e *endpoint) loadHardError(s string) {
	if s != "" {
		e.hardError = loadError(s)
	}
}

// messageToError maps each known tcpip error message to its canonical
// *tcpip.Error value. It is built lazily, exactly once, by loadError.
var messageToError map[string]*tcpip.Error

// populate guards the one-time construction of messageToError.
var populate sync.Once

// loadError returns the canonical *tcpip.Error whose String() equals s.
//
// It panics if s matches none of the errors listed below, and (while building
// the lookup table) if two listed errors share the same message, since the
// mapping would then be ambiguous.
func loadError(s string) *tcpip.Error {
	populate.Do(func() {
		known := []*tcpip.Error{
			tcpip.ErrUnknownProtocol,
			tcpip.ErrUnknownNICID,
			tcpip.ErrUnknownProtocolOption,
			tcpip.ErrDuplicateNICID,
			tcpip.ErrDuplicateAddress,
			tcpip.ErrNoRoute,
			tcpip.ErrBadLinkEndpoint,
			tcpip.ErrAlreadyBound,
			tcpip.ErrInvalidEndpointState,
			tcpip.ErrAlreadyConnecting,
			tcpip.ErrAlreadyConnected,
			tcpip.ErrNoPortAvailable,
			tcpip.ErrPortInUse,
			tcpip.ErrBadLocalAddress,
			tcpip.ErrClosedForSend,
			tcpip.ErrClosedForReceive,
			tcpip.ErrWouldBlock,
			tcpip.ErrConnectionRefused,
			tcpip.ErrTimeout,
			tcpip.ErrAborted,
			tcpip.ErrConnectStarted,
			tcpip.ErrDestinationRequired,
			tcpip.ErrNotSupported,
			tcpip.ErrQueueSizeNotSupported,
			tcpip.ErrNotConnected,
			tcpip.ErrConnectionReset,
			tcpip.ErrConnectionAborted,
			tcpip.ErrNoSuchFile,
			tcpip.ErrInvalidOptionValue,
			tcpip.ErrNoLinkAddress,
			tcpip.ErrBadAddress,
			tcpip.ErrNetworkUnreachable,
		}

		messageToError = make(map[string]*tcpip.Error, len(known))
		for _, te := range known {
			msg := te.String()
			if _, dup := messageToError[msg]; dup {
				panic("tcpip errors with duplicated message: " + msg)
			}
			messageToError[msg] = te
		}
	})

	if te, ok := messageToError[s]; ok {
		return te
	}
	panic("unknown error message: " + s)
}