diff options
-rw-r--r-- | ryu/lib/packet/icmp.py | 23 | ||||
-rw-r--r-- | ryu/lib/packet/icmpv6.py | 27 | ||||
-rw-r--r-- | ryu/tests/unit/packet/test_icmp.py | 2 | ||||
-rw-r--r-- | ryu/tests/unit/packet/test_icmpv6.py | 2 | ||||
-rw-r--r-- | ryu/tests/unit/packet/test_packet.py | 2 |
5 files changed, 36 insertions, 20 deletions
diff --git a/ryu/lib/packet/icmp.py b/ryu/lib/packet/icmp.py index 1a6cd76f..4ce24c3e 100644 --- a/ryu/lib/packet/icmp.py +++ b/ryu/lib/packet/icmp.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import struct +import six + from . import packet_base from . import packet_utils from ryu.lib import stringify @@ -77,7 +80,7 @@ class icmp(packet_base.PacketBase): return cls return _register_icmp_type - def __init__(self, type_=ICMP_ECHO_REQUEST, code=0, csum=0, data=None): + def __init__(self, type_=ICMP_ECHO_REQUEST, code=0, csum=0, data=b''): super(icmp, self).__init__() self.type = type_ self.code = code @@ -103,8 +106,9 @@ class icmp(packet_base.PacketBase): hdr = bytearray(struct.pack(icmp._PACK_STR, self.type, self.code, self.csum)) - if self.data is not None: + if self.data: if self.type in icmp._ICMP_TYPES: + assert isinstance(self.data, _ICMPv4Payload) hdr += self.data.serialize() else: hdr += self.data @@ -122,8 +126,15 @@ class icmp(packet_base.PacketBase): return self._MIN_LEN + len(self.data) +@six.add_metaclass(abc.ABCMeta) +class _ICMPv4Payload(stringify.StringifyMixin): + """ + Base class for the payload of ICMPv4 packet. + """ + + @icmp.register_icmp_type(ICMP_ECHO_REPLY, ICMP_ECHO_REQUEST) -class echo(stringify.StringifyMixin): +class echo(_ICMPv4Payload): """ICMP sub encoder/decoder class for Echo and Echo Reply messages. This is used with ryu.lib.packet.icmp.icmp for @@ -181,7 +192,7 @@ class echo(stringify.StringifyMixin): @icmp.register_icmp_type(ICMP_DEST_UNREACH) -class dest_unreach(stringify.StringifyMixin): +class dest_unreach(_ICMPv4Payload): """ICMP sub encoder/decoder class for Destination Unreachable Message. This is used with ryu.lib.packet.icmp.icmp for @@ -252,7 +263,7 @@ class dest_unreach(stringify.StringifyMixin): @icmp.register_icmp_type(ICMP_TIME_EXCEEDED) -class TimeExceeded(stringify.StringifyMixin): +class TimeExceeded(_ICMPv4Payload): """ICMP sub encoder/decoder class for Time Exceeded Message. This is used with ryu.lib.packet.icmp.icmp for @@ -278,7 +289,7 @@ class TimeExceeded(stringify.StringifyMixin): _MIN_LEN = struct.calcsize(_PACK_STR) def __init__(self, data_len=0, data=None): - if ((data_len >= 0) and (data_len <= 255)): + if (data_len >= 0) and (data_len <= 255): self.data_len = data_len else: raise ValueError('Specified data length (%d) is invalid.' % data_len) diff --git a/ryu/lib/packet/icmpv6.py b/ryu/lib/packet/icmpv6.py index fe94c77f..53d5ec3a 100644 --- a/ryu/lib/packet/icmpv6.py +++ b/ryu/lib/packet/icmpv6.py @@ -112,7 +112,7 @@ class icmpv6(packet_base.PacketBase): return cls return _register_icmpv6_type - def __init__(self, type_=0, code=0, csum=0, data=None): + def __init__(self, type_=0, code=0, csum=0, data=b''): super(icmpv6, self).__init__() self.type_ = type_ self.code = code @@ -137,8 +137,9 @@ class icmpv6(packet_base.PacketBase): hdr = bytearray(struct.pack(icmpv6._PACK_STR, self.type_, self.code, self.csum)) - if self.data is not None: + if self.data: if self.type_ in icmpv6._ICMPV6_TYPES: + assert isinstance(self.data, _ICMPv6Payload) hdr += self.data.serialize() else: hdr += self.data @@ -149,14 +150,18 @@ class icmpv6(packet_base.PacketBase): return hdr def __len__(self): - length = self._MIN_LEN - if self.data is not None: - length += len(self.data) - return length + return self._MIN_LEN + len(self.data) + + +@six.add_metaclass(abc.ABCMeta) +class _ICMPv6Payload(stringify.StringifyMixin): + """ + Base class for the payload of ICMPv6 packet. + """ @icmpv6.register_icmpv6_type(ND_NEIGHBOR_SOLICIT, ND_NEIGHBOR_ADVERT) -class nd_neighbor(stringify.StringifyMixin): +class nd_neighbor(_ICMPv6Payload): """ICMPv6 sub encoder/decoder class for Neighbor Solicitation and Neighbor Advertisement messages. (RFC 4861) @@ -237,7 +242,7 @@ class nd_neighbor(stringify.StringifyMixin): @icmpv6.register_icmpv6_type(ND_ROUTER_SOLICIT) -class nd_router_solicit(stringify.StringifyMixin): +class nd_router_solicit(_ICMPv6Payload): """ICMPv6 sub encoder/decoder class for Router Solicitation messages. (RFC 4861) @@ -308,7 +313,7 @@ class nd_router_solicit(stringify.StringifyMixin): @icmpv6.register_icmpv6_type(ND_ROUTER_ADVERT) -class nd_router_advert(stringify.StringifyMixin): +class nd_router_advert(_ICMPv6Payload): """ICMPv6 sub encoder/decoder class for Router Advertisement messages. (RFC 4861) @@ -619,7 +624,7 @@ class nd_option_pi(nd_option): @icmpv6.register_icmpv6_type(ICMPV6_ECHO_REPLY, ICMPV6_ECHO_REQUEST) -class echo(stringify.StringifyMixin): +class echo(_ICMPv6Payload): """ICMPv6 sub encoder/decoder class for Echo Request and Echo Reply messages. @@ -675,7 +680,7 @@ class echo(stringify.StringifyMixin): @icmpv6.register_icmpv6_type( MLD_LISTENER_QUERY, MLD_LISTENER_REPOR, MLD_LISTENER_DONE) -class mld(stringify.StringifyMixin): +class mld(_ICMPv6Payload): """ICMPv6 sub encoder/decoder class for MLD Lister Query, MLD Listener Report, and MLD Listener Done messages. (RFC 2710) diff --git a/ryu/tests/unit/packet/test_icmp.py b/ryu/tests/unit/packet/test_icmp.py index f9438893..ca96b262 100644 --- a/ryu/tests/unit/packet/test_icmp.py +++ b/ryu/tests/unit/packet/test_icmp.py @@ -45,7 +45,7 @@ class Test_icmp(unittest.TestCase): self.type_ = icmp.ICMP_ECHO_REQUEST self.code = 0 self.csum = 0 - self.data = None + self.data = b'' self.ic = icmp.icmp(self.type_, self.code, self.csum, self.data) diff --git a/ryu/tests/unit/packet/test_icmpv6.py b/ryu/tests/unit/packet/test_icmpv6.py index c6438171..9bb7482b 100644 --- a/ryu/tests/unit/packet/test_icmpv6.py +++ b/ryu/tests/unit/packet/test_icmpv6.py @@ -69,7 +69,7 @@ class Test_icmpv6_header(unittest.TestCase): eq_(msg.type_, self.type_) eq_(msg.code, self.code) eq_(msg.csum, self.csum) - eq_(msg.data, None) + eq_(msg.data, b'') eq_(n, None) def test_serialize(self): diff --git a/ryu/tests/unit/packet/test_packet.py b/ryu/tests/unit/packet/test_packet.py index 1b4c704b..80d22143 100644 --- a/ryu/tests/unit/packet/test_packet.py +++ b/ryu/tests/unit/packet/test_packet.py @@ -1365,7 +1365,7 @@ class TestPacket(unittest.TestCase): icmpv6_values = {'type_': 0, 'code': 0, 'csum': p_icmpv6.csum, - 'data': None} + 'data': b''} _icmpv6_str = ','.join(['%s=%s' % (k, repr(icmpv6_values[k])) for k, _ in inspect.getmembers(p_icmpv6) if k in icmpv6_values]) |