summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4
diff options
context:
space:
mode:
authorIan Gudger <igudger@google.com>2019-10-07 19:28:26 -0700
committergVisor bot <gvisor-bot@google.com>2019-10-07 19:29:51 -0700
commit7c1587e3401a010d1865df61dbaf117c77dd062e (patch)
tree53392ccc3fc1d4cfa967f0d7f72e5920ed18fa5d /pkg/tcpip/network/ipv4
parent1de0cf3563502c1460964fc2fc9dca1ee447449a (diff)
Implement IP_TTL.
Also change the default TTL to 64 to match Linux. PiperOrigin-RevId: 273430341
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go38
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go4
3 files changed, 37 insertions, 7 deletions
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index a25756443..c1cf6c222 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -95,7 +95,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil {
+ if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */); err != nil {
sent.Dropped.Increment()
return
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b7b07a6c1..162aa1b4d 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -39,6 +39,9 @@ const (
// TotalLength field of the ipv4 header.
MaxTotalSize = 0xffff
+ // DefaultTTL is the default time-to-live value for this endpoint.
+ DefaultTTL = 64
+
// buckets is the number of identifier buckets.
buckets = 2048
)
@@ -70,7 +73,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
// DefaultTTL is the default time-to-live value for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
- return 255
+ return e.protocol.DefaultTTL()
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
@@ -327,6 +330,11 @@ func (e *endpoint) Close() {}
type protocol struct {
ids []uint32
hashIV uint32
+
+ // defaultTTL is the current default TTL for the protocol. Only the
+ // uint8 portion of it is meaningful and it must be accessed
+ // atomically.
+ defaultTTL uint32
}
// Number returns the ipv4 protocol number.
@@ -352,12 +360,34 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
// SetOption implements NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(v))
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// Option implements NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch v := option.(type) {
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(p.DefaultTTL())
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetDefaultTTL sets the default TTL for endpoints created with this protocol.
+func (p *protocol) SetDefaultTTL(ttl uint8) {
+ atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
+}
+
+// DefaultTTL returns the default TTL for endpoints created with this protocol.
+func (p *protocol) DefaultTTL() uint8 {
+ return uint8(atomic.LoadUint32(&p.defaultTTL))
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -391,5 +421,5 @@ func NewProtocol() stack.NetworkProtocol {
}
hashIV := r[buckets]
- return &protocol{ids: ids, hashIV: hashIV}
+ return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index a53894c01..8b7500095 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -302,7 +302,7 @@ func TestFragmentation(t *testing.T) {
Payload: payload.Clone([]buffer.View{}),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42)
+ err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */)
if err != nil {
t.Errorf("err got %v, want %v", err, nil)
}
@@ -349,7 +349,7 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42)
+ err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */)
for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)