summaryrefslogtreecommitdiffhomepage
path: root/server
diff options
context:
space:
mode:
authorFUJITA Tomonori <fujita.tomonori@lab.ntt.co.jp>2017-02-21 13:51:26 +0900
committerFUJITA Tomonori <fujita.tomonori@lab.ntt.co.jp>2017-02-21 14:59:19 +0900
commit844cfba67ed62cbb13c91831672def55ccb8df60 (patch)
tree4dd80a65c66e5152f3e72fb1260f85db964647d3 /server
parentfd010affbab5746fa5c0c55cd7803c3dbe2bbf9d (diff)
rpki: use context instead of tomb
Let's use context, the standard way to handle cancellation. Signed-off-by: FUJITA Tomonori <fujita.tomonori@lab.ntt.co.jp>
Diffstat (limited to 'server')
-rw-r--r--server/rpki.go93
1 files changed, 47 insertions, 46 deletions
diff --git a/server/rpki.go b/server/rpki.go
index 9e81d360..5c8c3db9 100644
--- a/server/rpki.go
+++ b/server/rpki.go
@@ -30,7 +30,11 @@ import (
"github.com/osrg/gobgp/packet/bgp"
"github.com/osrg/gobgp/packet/rtr"
"github.com/osrg/gobgp/table"
- "gopkg.in/tomb.v2"
+ "golang.org/x/net/context"
+)
+
+const (
+ CONNECT_RETRY_INTERVAL = 30
)
func before(a, b uint32) bool {
@@ -129,9 +133,7 @@ func (m *roaManager) AddServer(host string, lifetime int64) error {
if _, ok := m.clientMap[host]; ok {
return fmt.Errorf("ROA server exists %s", host)
}
- client := NewRoaClient(address, port, m.eventCh, lifetime)
- m.clientMap[host] = client
- client.t.Go(client.tryConnect)
+ m.clientMap[host] = NewRoaClient(address, port, m.eventCh, lifetime)
return nil
}
@@ -140,7 +142,8 @@ func (m *roaManager) DeleteServer(host string) error {
if !ok {
return fmt.Errorf("ROA server doesn't exists %s", host)
}
- client.reset()
+ client.stop()
+ m.deleteAllROA(host)
delete(m.clientMap, host)
return nil
}
@@ -185,6 +188,7 @@ func (m *roaManager) Disable(address string) error {
add, _, _ := net.SplitHostPort(network)
if add == address {
client.reset()
+ m.deleteAllROA(add)
return nil
}
}
@@ -192,14 +196,7 @@ func (m *roaManager) Disable(address string) error {
}
func (m *roaManager) Reset(address string) error {
- for network, client := range m.clientMap {
- add, _, _ := net.SplitHostPort(network)
- if add == address {
- client.reset()
- return nil
- }
- }
- return fmt.Errorf("ROA server not found %s", address)
+ return m.Disable(address)
}
func (m *roaManager) SoftReset(address string) error {
@@ -243,16 +240,14 @@ func (m *roaManager) HandleROAEvent(ev *ROAEvent) {
client.pendingROAs = make([]*table.ROA, 0)
client.state.RpkiMessages = config.RpkiMessages{}
client.conn = nil
- client.t = tomb.Tomb{}
- client.t.Go(client.tryConnect)
+ go client.tryConnect()
client.timer = time.AfterFunc(time.Duration(client.lifetime)*time.Second, client.lifetimeout)
client.oldSessionID = client.sessionID
case CONNECTED:
log.WithFields(log.Fields{"Topic": "rpki"}).Infof("ROA server %s is connected", ev.Src)
client.conn = ev.conn
client.state.Uptime = time.Now().Unix()
- client.t = tomb.Tomb{}
- client.t.Go(client.established)
+ go client.established()
case RTR:
m.handleRTRMsg(client, &client.state, ev.Data)
case LIFETIMEOUT:
@@ -557,7 +552,6 @@ func (c *roaManager) validate(pathList []*table.Path) {
}
type roaClient struct {
- t tomb.Tomb
host string
conn *net.TCPConn
state config.RpkiServerState
@@ -569,15 +563,22 @@ type roaClient struct {
lifetime int64
endOfData bool
pendingROAs []*table.ROA
+ cancelfnc context.CancelFunc
+ ctx context.Context
}
func NewRoaClient(address, port string, ch chan *ROAEvent, lifetime int64) *roaClient {
- return &roaClient{
+ ctx, cancel := context.WithCancel(context.Background())
+ c := &roaClient{
host: net.JoinHostPort(address, port),
eventCh: ch,
lifetime: lifetime,
pendingROAs: make([]*table.ROA, 0),
+ ctx: ctx,
+ cancelfnc: cancel,
}
+ go c.tryConnect()
+ return c
}
func (c *roaClient) enable(serial uint32) error {
@@ -609,60 +610,63 @@ func (c *roaClient) softReset() error {
}
func (c *roaClient) reset() {
- c.t.Kill(nil)
if c.conn != nil {
c.conn.Close()
}
}
-func (c *roaClient) tryConnect() error {
- for c.t.Alive() {
- conn, err := net.Dial("tcp", c.host)
- if err != nil {
- time.Sleep(30 * time.Second)
+func (c *roaClient) stop() {
+ c.cancelfnc()
+ c.reset()
+}
+
+func (c *roaClient) tryConnect() {
+ for {
+ select {
+ case <-c.ctx.Done():
+ return
+ default:
+ }
+ if conn, err := net.Dial("tcp", c.host); err != nil {
+ // better to use context with timeout
+ time.Sleep(CONNECT_RETRY_INTERVAL * time.Second)
} else {
c.eventCh <- &ROAEvent{
EventType: CONNECTED,
Src: c.host,
conn: conn.(*net.TCPConn),
}
- return nil
+ return
}
}
- return nil
}
-func (c *roaClient) established() error {
- defer c.conn.Close()
-
- disconnected := func() {
+func (c *roaClient) established() (err error) {
+ defer func() {
+ c.conn.Close()
c.eventCh <- &ROAEvent{
EventType: DISCONNECTED,
Src: c.host,
}
- }
+ }()
- err := c.softReset()
- if err != nil {
- disconnected()
- return nil
+ if err := c.softReset(); err != nil {
+ return err
}
for {
header := make([]byte, rtr.RTR_MIN_LEN)
- _, err := io.ReadFull(c.conn, header)
- if err != nil {
- break
+ if _, err = io.ReadFull(c.conn, header); err != nil {
+ return err
}
totalLen := binary.BigEndian.Uint32(header[4:8])
if totalLen < rtr.RTR_MIN_LEN {
- break
+ return fmt.Errorf("too short header length %v", totalLen)
}
body := make([]byte, totalLen-rtr.RTR_MIN_LEN)
- _, err = io.ReadFull(c.conn, body)
- if err != nil {
- break
+ if _, err = io.ReadFull(c.conn, body); err != nil {
+ return
}
c.eventCh <- &ROAEvent{
@@ -670,8 +674,5 @@ func (c *roaClient) established() error {
Src: c.host,
Data: append(header, body...),
}
-
}
- disconnected()
- return nil
}