summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/link
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/link')
-rw-r--r--pkg/tcpip/link/tunnel/gre.go20
1 files changed, 18 insertions, 2 deletions
diff --git a/pkg/tcpip/link/tunnel/gre.go b/pkg/tcpip/link/tunnel/gre.go
index c62cfa283..64a54f11e 100644
--- a/pkg/tcpip/link/tunnel/gre.go
+++ b/pkg/tcpip/link/tunnel/gre.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/gre"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -126,10 +127,23 @@ func (info *GreHandlerInfo) greRead(ep *channel.Endpoint) {
}
}
-func (e *Endpoint) Start(s *stack.Stack, laddr, raddr *tcpip.Address) {
+func networkProtocolNumber(addr *tcpip.Address) tcpip.NetworkProtocolNumber {
+ if addr.To4() != "" {
+ return ipv4.ProtocolNumber
+ } else {
+ return ipv6.ProtocolNumber
+ }
+}
+
+func (e *Endpoint) Start(s *stack.Stack, laddr, raddr *tcpip.Address) *tcpip.Error {
+ proto := networkProtocolNumber(laddr)
+ if proto != networkProtocolNumber(raddr) {
+ return tcpip.ErrBadAddress
+ }
+
// Create TCP endpoint.
var rawWq waiter.Queue
- rawEp, tcperr := raw.NewEndpoint(s, ipv4.ProtocolNumber, header.GREProtocolNumber, &rawWq)
+ rawEp, tcperr := raw.NewEndpoint(s, proto, header.GREProtocolNumber, &rawWq)
if tcperr != nil {
log.Fatal(tcperr)
}
@@ -164,4 +178,6 @@ func (e *Endpoint) Start(s *stack.Stack, laddr, raddr *tcpip.Address) {
if tcperr != nil {
log.Fatal(tcperr)
}
+
+ return nil
}