summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--paramiko/hostkeys.py41
-rw-r--r--tests/test_hostkeys.py3
2 files changed, 34 insertions, 10 deletions
diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py
index 0a9c10bb..316acc6c 100644
--- a/paramiko/hostkeys.py
+++ b/paramiko/hostkeys.py
@@ -188,18 +188,43 @@ class HostKeys (UserDict.DictMixin):
@return: keys associated with this host (or C{None})
@rtype: dict(str, L{PKey})
"""
- ret = {}
- valid = False
+ class SubDict (UserDict.DictMixin):
+ def __init__(self, hostname, entries, hostkeys):
+ self._hostname = hostname
+ self._entries = entries
+ self._hostkeys = hostkeys
+
+ def __getitem__(self, key):
+ for e in self._entries:
+ if e.key.get_name() == key:
+ return e.key
+ raise KeyError(key)
+
+ def __setitem__(self, key, val):
+ for e in self._entries:
+ if e.key is None:
+ continue
+ if e.key.get_name() == key:
+ # replace
+ e.key = val
+ break
+ else:
+ # add a new one
+ e = HostKeyEntry([hostname], val)
+ self._entries.append(e)
+ self._hostkeys._entries.append(e)
+
+ def keys(self):
+ return [e.key.get_name() for e in self._entries if e.key is not None]
+
+ entries = []
for e in self._entries:
for h in e.hostnames:
if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname):
- valid = True
- if e.key is None:
- continue
- ret[e.key.get_name()] = e.key
- if not valid:
+ entries.append(e)
+ if len(entries) == 0:
return None
- return ret
+ return SubDict(hostname, entries, self)
def check(self, hostname, key):
"""
diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py
index e9580ddf..24303570 100644
--- a/tests/test_hostkeys.py
+++ b/tests/test_hostkeys.py
@@ -105,13 +105,12 @@ class HostKeysTest (unittest.TestCase):
'ssh-dss': key_dss
}
hostdict['fake.example.com'] = {}
- # this line will have no effect, but at least shouldn't crash:
hostdict['fake.example.com']['ssh-rsa'] = key
self.assertEquals(3, len(hostdict))
self.assertEquals(2, len(hostdict.values()[0]))
self.assertEquals(1, len(hostdict.values()[1]))
- self.assertEquals(0, len(hostdict.values()[2]))
+ self.assertEquals(1, len(hostdict.values()[2]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()