summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/hostkeys.py33
-rw-r--r--tests/test_hostkeys.py11
2 files changed, 28 insertions, 16 deletions
diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py
index 5ef27160..fae25c22 100644
--- a/paramiko/hostkeys.py
+++ b/paramiko/hostkeys.py
@@ -22,13 +22,14 @@ L{HostKeys}
import base64
from Crypto.Hash import SHA, HMAC
+import UserDict
from paramiko.common import *
from paramiko.dsskey import DSSKey
from paramiko.rsakey import RSAKey
-class HostKeys (object):
+class HostKeys (UserDict.DictMixin):
"""
Representation of an openssh-style "known hosts" file. Host keys can be
read from one or more files, and then individual hosts can be looked up to
@@ -49,7 +50,7 @@ class HostKeys (object):
@type filename: str
"""
# hostname -> keytype -> PKey
- self.keys = {}
+ self._keys = {}
self.contains_hashes = False
if filename is not None:
self.load(filename)
@@ -66,11 +67,11 @@ class HostKeys (object):
@param key: the key to add
@type key: L{PKey}
"""
- if not hostname in self.keys:
- self.keys[hostname] = {}
+ if not hostname in self._keys:
+ self._keys[hostname] = {}
if hostname.startswith('|1|'):
self.contains_hashes = True
- self.keys[hostname][keytype] = key
+ self._keys[hostname][keytype] = key
def load(self, filename):
"""
@@ -110,15 +111,15 @@ class HostKeys (object):
@return: keys associated with this host (or C{None})
@rtype: dict(str, L{PKey})
"""
- if hostname in self.keys:
- return self.keys[hostname]
+ if hostname in self._keys:
+ return self._keys[hostname]
if not self.contains_hashes:
return None
- for h in self.keys.keys():
+ for h in self._keys.keys():
if h.startswith('|1|'):
hmac = self.hash_host(hostname, h)
if hmac == h:
- return self.keys[h]
+ return self._keys[h]
return None
def check(self, hostname, key):
@@ -146,21 +147,21 @@ class HostKeys (object):
"""
Remove all host keys from the dictionary.
"""
- self.keys = {}
+ self._keys = {}
self.contains_hashes = False
- def values(self):
- return self.keys.values();
-
def __getitem__(self, key):
ret = self.lookup(key)
if ret is None:
raise KeyError(key)
return ret
- def __len__(self):
- return len(self.keys)
-
+ def keys(self):
+ return self._keys.keys()
+
+ def values(self):
+ return self._keys.values();
+
def hash_host(hostname, salt=None):
"""
Return a "hashed" form of the hostname, as used by openssh when storing
diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py
index 13426387..6f8eb57e 100644
--- a/tests/test_hostkeys.py
+++ b/tests/test_hostkeys.py
@@ -71,3 +71,14 @@ class HostKeysTest (unittest.TestCase):
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
self.assertTrue(hostdict.check('foo.example.com', key))
+
+ def test_3_dict(self):
+ hostdict = paramiko.HostKeys('hostfile.temp')
+ self.assert_('secure.example.com' in hostdict)
+ self.assert_('not.example.com' not in hostdict)
+ self.assert_(hostdict.has_key('secure.example.com'))
+ self.assert_(not hostdict.has_key('not.example.com'))
+ x = hostdict.get('secure.example.com', None)
+ self.assertTrue(x is not None)
+ fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
+ self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)