diff options
-rw-r--r-- | ryu/app/rest_router.py | 47 |
1 files changed, 36 insertions, 11 deletions
diff --git a/ryu/app/rest_router.py b/ryu/app/rest_router.py index 31ef91f8..c2c53297 100644 --- a/ryu/app/rest_router.py +++ b/ryu/app/rest_router.py @@ -1252,8 +1252,9 @@ class AddressData(dict): for other in self.values(): other_mask = mask_ntob(other.netmask) add_mask = mask_ntob(mask, err_msg=err_msg) - if (other.nw_addr == default_gw & other_mask - or nw_addr == other.default_gw & add_mask): + if (other.nw_addr == ipv4_apply_mask(default_gw, other.netmask) or + nw_addr == ipv4_apply_mask(other.default_gw, mask, + err_msg)): msg = 'Address overlaps [address_id=%d]' % other.address_id raise CommandFailure(msg=msg) @@ -1285,7 +1286,7 @@ class AddressData(dict): return address else: assert ip is not None - if ip & mask_ntob(address.netmask) == address.nw_addr: + if ipv4_apply_mask(ip, address.netmask) == address.nw_addr: return address return None @@ -1299,7 +1300,7 @@ class Address(object): self.default_gw = default_gw def __contains__(self, ip): - return bool(ip & mask_ntob(self.netmask) == self.nw_addr) + return bool(ipv4_apply_mask(ip, self.netmask) == self.nw_addr) class RoutingTable(dict): @@ -1363,7 +1364,7 @@ class RoutingTable(dict): get_route = None mask = 0 for route in self.values(): - if dst_ip & mask_ntob(route.netmask) == route.dst_ip: + if ipv4_apply_mask(dst_ip, route.netmask) == route.dst_ip: # For longest match if mask < route.netmask: get_route = route @@ -1634,7 +1635,9 @@ class OfCtl_v1_0(OfCtl): wildcards &= ~ofp.OFPFW_NW_PROTO match = ofp_parser.OFPMatch(wildcards, 0, 0, dl_dst, dl_vlan, 0, - dl_type, 0, nw_proto, nw_src, nw_dst, + dl_type, 0, nw_proto, + ipv4_bytes_to_int(nw_src), + ipv4_bytes_to_int(nw_dst), 0, 0) actions = actions or [] @@ -1725,9 +1728,11 @@ class OfCtl_v1_2(OfCtl): if dl_vlan: match.set_vlan_vid(dl_vlan) if nw_src: - match.set_ipv4_src_masked(nw_src, mask_ntob(src_mask)) + match.set_ipv4_src_masked(ipv4_bytes_to_int(nw_src), + mask_ntob(src_mask)) if nw_dst: - match.set_ipv4_dst_masked(nw_dst, mask_ntob(dst_mask)) + match.set_ipv4_dst_masked(ipv4_bytes_to_int(nw_dst), + mask_ntob(dst_mask)) if nw_proto: if dl_type == ether.ETH_TYPE_IP: match.set_ip_proto(nw_proto) @@ -1790,7 +1795,7 @@ class OfCtl_v1_2(OfCtl): def ip_addr_aton(ip_str, err_msg=None): try: - return struct.unpack('!I', socket.inet_aton(ip_str))[0] + return socket.inet_aton(ip_str) except (struct.error, socket.error) as e: if err_msg is not None: e.message = '%s %s' % (err_msg, e.message) @@ -1798,7 +1803,7 @@ def ip_addr_aton(ip_str, err_msg=None): def ip_addr_ntoa(ip): - return socket.inet_ntoa(struct.pack('!I', ip)) + return socket.inet_ntoa(ip) def mask_ntob(mask, err_msg=None): @@ -1811,6 +1816,26 @@ def mask_ntob(mask, err_msg=None): raise ValueError(msg) +def ipv4_apply_mask(address, prefix_len, err_msg=None): + import itertools + + assert isinstance(address, bytes) + assert len(address) == 4 + mask = ipv4_int_to_bytes(mask_ntob(prefix_len, err_msg)) + return ''.join(chr(ord(x) & ord(y)) for (x, y) in + itertools.izip(address, mask)) + + +def ipv4_int_to_bytes(ip_int): + assert isinstance(ip_int, (int, long)) + return struct.pack('!I', ip_int) + + +def ipv4_bytes_to_int(ip_bytes): + assert isinstance(ip_bytes, bytes) + return struct.unpack('!I', ip_bytes)[0] + + def nw_addr_aton(nw_addr, err_msg=None): ip_mask = nw_addr.split('/') default_route = ip_addr_aton(ip_mask[0], err_msg=err_msg) @@ -1827,5 +1852,5 @@ def nw_addr_aton(nw_addr, err_msg=None): if err_msg is not None: msg = '%s %s' % (err_msg, msg) raise ValueError(msg) - nw_addr = default_route & mask_ntob(netmask, err_msg=err_msg) + nw_addr = ipv4_apply_mask(default_route, netmask, err_msg) return nw_addr, netmask, default_route |