diff options
author | FUJITA Tomonori <fujita.tomonori@lab.ntt.co.jp> | 2017-02-21 13:51:26 +0900 |
---|---|---|
committer | FUJITA Tomonori <fujita.tomonori@lab.ntt.co.jp> | 2017-02-21 14:59:19 +0900 |
commit | 844cfba67ed62cbb13c91831672def55ccb8df60 (patch) | |
tree | 4dd80a65c66e5152f3e72fb1260f85db964647d3 /server/rpki.go | |
parent | fd010affbab5746fa5c0c55cd7803c3dbe2bbf9d (diff) |
rpki: use context instead of tomb
Let's use context, the standard way to handle cancellation.
Signed-off-by: FUJITA Tomonori <fujita.tomonori@lab.ntt.co.jp>
Diffstat (limited to 'server/rpki.go')
-rw-r--r-- | server/rpki.go | 93 |
1 files changed, 47 insertions, 46 deletions
diff --git a/server/rpki.go b/server/rpki.go index 9e81d360..5c8c3db9 100644 --- a/server/rpki.go +++ b/server/rpki.go @@ -30,7 +30,11 @@ import ( "github.com/osrg/gobgp/packet/bgp" "github.com/osrg/gobgp/packet/rtr" "github.com/osrg/gobgp/table" - "gopkg.in/tomb.v2" + "golang.org/x/net/context" +) + +const ( + CONNECT_RETRY_INTERVAL = 30 ) func before(a, b uint32) bool { @@ -129,9 +133,7 @@ func (m *roaManager) AddServer(host string, lifetime int64) error { if _, ok := m.clientMap[host]; ok { return fmt.Errorf("ROA server exists %s", host) } - client := NewRoaClient(address, port, m.eventCh, lifetime) - m.clientMap[host] = client - client.t.Go(client.tryConnect) + m.clientMap[host] = NewRoaClient(address, port, m.eventCh, lifetime) return nil } @@ -140,7 +142,8 @@ func (m *roaManager) DeleteServer(host string) error { if !ok { return fmt.Errorf("ROA server doesn't exists %s", host) } - client.reset() + client.stop() + m.deleteAllROA(host) delete(m.clientMap, host) return nil } @@ -185,6 +188,7 @@ func (m *roaManager) Disable(address string) error { add, _, _ := net.SplitHostPort(network) if add == address { client.reset() + m.deleteAllROA(add) return nil } } @@ -192,14 +196,7 @@ func (m *roaManager) Disable(address string) error { } func (m *roaManager) Reset(address string) error { - for network, client := range m.clientMap { - add, _, _ := net.SplitHostPort(network) - if add == address { - client.reset() - return nil - } - } - return fmt.Errorf("ROA server not found %s", address) + return m.Disable(address) } func (m *roaManager) SoftReset(address string) error { @@ -243,16 +240,14 @@ func (m *roaManager) HandleROAEvent(ev *ROAEvent) { client.pendingROAs = make([]*table.ROA, 0) client.state.RpkiMessages = config.RpkiMessages{} client.conn = nil - client.t = tomb.Tomb{} - client.t.Go(client.tryConnect) + go client.tryConnect() client.timer = time.AfterFunc(time.Duration(client.lifetime)*time.Second, client.lifetimeout) client.oldSessionID = client.sessionID case CONNECTED: log.WithFields(log.Fields{"Topic": "rpki"}).Infof("ROA server %s is connected", ev.Src) client.conn = ev.conn client.state.Uptime = time.Now().Unix() - client.t = tomb.Tomb{} - client.t.Go(client.established) + go client.established() case RTR: m.handleRTRMsg(client, &client.state, ev.Data) case LIFETIMEOUT: @@ -557,7 +552,6 @@ func (c *roaManager) validate(pathList []*table.Path) { } type roaClient struct { - t tomb.Tomb host string conn *net.TCPConn state config.RpkiServerState @@ -569,15 +563,22 @@ type roaClient struct { lifetime int64 endOfData bool pendingROAs []*table.ROA + cancelfnc context.CancelFunc + ctx context.Context } func NewRoaClient(address, port string, ch chan *ROAEvent, lifetime int64) *roaClient { - return &roaClient{ + ctx, cancel := context.WithCancel(context.Background()) + c := &roaClient{ host: net.JoinHostPort(address, port), eventCh: ch, lifetime: lifetime, pendingROAs: make([]*table.ROA, 0), + ctx: ctx, + cancelfnc: cancel, } + go c.tryConnect() + return c } func (c *roaClient) enable(serial uint32) error { @@ -609,60 +610,63 @@ func (c *roaClient) softReset() error { } func (c *roaClient) reset() { - c.t.Kill(nil) if c.conn != nil { c.conn.Close() } } -func (c *roaClient) tryConnect() error { - for c.t.Alive() { - conn, err := net.Dial("tcp", c.host) - if err != nil { - time.Sleep(30 * time.Second) +func (c *roaClient) stop() { + c.cancelfnc() + c.reset() +} + +func (c *roaClient) tryConnect() { + for { + select { + case <-c.ctx.Done(): + return + default: + } + if conn, err := net.Dial("tcp", c.host); err != nil { + // better to use context with timeout + time.Sleep(CONNECT_RETRY_INTERVAL * time.Second) } else { c.eventCh <- &ROAEvent{ EventType: CONNECTED, Src: c.host, conn: conn.(*net.TCPConn), } - return nil + return } } - return nil } -func (c *roaClient) established() error { - defer c.conn.Close() - - disconnected := func() { +func (c *roaClient) established() (err error) { + defer func() { + c.conn.Close() c.eventCh <- &ROAEvent{ EventType: DISCONNECTED, Src: c.host, } - } + }() - err := c.softReset() - if err != nil { - disconnected() - return nil + if err := c.softReset(); err != nil { + return err } for { header := make([]byte, rtr.RTR_MIN_LEN) - _, err := io.ReadFull(c.conn, header) - if err != nil { - break + if _, err = io.ReadFull(c.conn, header); err != nil { + return err } totalLen := binary.BigEndian.Uint32(header[4:8]) if totalLen < rtr.RTR_MIN_LEN { - break + return fmt.Errorf("too short header length %v", totalLen) } body := make([]byte, totalLen-rtr.RTR_MIN_LEN) - _, err = io.ReadFull(c.conn, body) - if err != nil { - break + if _, err = io.ReadFull(c.conn, body); err != nil { + return } c.eventCh <- &ROAEvent{ @@ -670,8 +674,5 @@ func (c *roaClient) established() error { Src: c.host, Data: append(header, body...), } - } - disconnected() - return nil } |