summaryrefslogtreecommitdiffhomepage
path: root/message.py
blob: 0660fe63c5d1ca14a4566df381ab1c23dfaf574e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# implementation of an SSH2 "message"

import string, types, struct
from util import inflate_long, deflate_long


class Message(object):
    """
    Represents the encoding of an SSH2 message.

    A message wraps a byte string (``packet``).  The ``get_*`` methods read
    successive fields starting at the read cursor ``idx``; the ``add_*``
    methods append an encoded field to the end of ``packet`` and return
    ``self`` so that calls can be chained.
    """

    # Largest value that fits the unsigned 32-bit field written by add_int()
    # (2**32 - 1).  Spelled in decimal so it is a positive number everywhere.
    _MAX_UINT32 = 4294967295

    # Integer types accepted by add().  ``long`` is looked up defensively so
    # the class body still evaluates on an interpreter that has no separate
    # long type.
    try:
        _INTEGER_TYPES = (int, long)  # noqa: F821
    except NameError:
        _INTEGER_TYPES = (int,)

    def __init__(self, content=''):
        """Create a message whose packet starts out as ``content``."""
        self.packet = content
        # read cursor: offset of the first byte not yet consumed by get_*()
        self.idx = 0
        # never touched by this class; presumably assigned by whoever
        # sends/receives the message -- TODO confirm against callers
        self.seqno = -1

    def __str__(self):
        """Return the raw packet contents."""
        return self.packet

    def __repr__(self):
        return 'Message(' + repr(self.packet) + ')'

    def get_remainder(self):
        """Return the bytes that have not been parsed yet."""
        return self.packet[self.idx:]

    def get_so_far(self):
        """Return the bytes that have already been parsed."""
        return self.packet[:self.idx]

    def get_bytes(self, n):
        """
        Return the next ``n`` bytes and advance the cursor past them.

        If fewer than ``n`` bytes remain, ``n`` zero bytes are returned
        instead and the cursor is left where it was.
        """
        end = self.idx + n
        if end > len(self.packet):
            return '\x00' * n
        chunk = self.packet[self.idx:end]
        self.idx = end
        return chunk

    def get_byte(self):
        """Return the next single byte (see get_bytes for short reads)."""
        return self.get_bytes(1)

    def get_boolean(self):
        """
        Read one byte and return 0 if it is a zero byte, otherwise 1.

        When no data is left, get_bytes() supplies a zero byte, so the
        result is 0.
        """
        if self.get_bytes(1) == '\x00':
            return 0
        return 1

    def get_int(self):
        """
        Read a big-endian unsigned 32-bit integer.

        Returns 0, without moving the cursor, when fewer than 4 bytes
        remain.
        """
        start = self.idx
        end = start + 4
        if end > len(self.packet):
            return 0
        value = struct.unpack('>I', self.packet[start:end])[0]
        self.idx = end
        return value

    def get_mpint(self):
        """
        Read a length-prefixed string and convert it with inflate_long().
        """
        return inflate_long(self.get_string())

    def get_string(self):
        """
        Read a string stored as a 32-bit length followed by that many bytes.

        Returns '' when the declared length runs past the end of the packet;
        in that case the length prefix itself has already been consumed.
        """
        length = self.get_int()
        end = self.idx + length
        if end > len(self.packet):
            return ''
        data = self.packet[self.idx:end]
        self.idx = end
        return data

    def get_list(self):
        """
        Read a string and split it on commas into a list.

        An empty string yields a list holding one empty string.
        """
        return self.get_string().split(',')

    def add_bytes(self, b):
        """Append the raw bytes ``b`` with no length prefix."""
        self.packet = self.packet + b
        return self

    def add_byte(self, b):
        """Append ``b`` as-is; identical to add_bytes()."""
        return self.add_bytes(b)

    def add_boolean(self, b):
        """Append one byte: '\\x01' if ``b`` is true, else '\\x00'."""
        if b:
            self.add_byte('\x01')
        else:
            self.add_byte('\x00')
        return self

    def add_int(self, n):
        """Append ``n`` as a big-endian unsigned 32-bit integer."""
        self.packet = self.packet + struct.pack('>I', n)
        return self

    def add_mpint(self, z):
        """
        Append ``z`` as a length-prefixed string produced by deflate_long().

        This only works on positive numbers.
        """
        self.add_string(deflate_long(z))
        return self

    def add_string(self, s):
        """Append ``s`` preceded by its length as a 32-bit integer."""
        self.add_int(len(s))
        self.packet = self.packet + s
        return self

    def add_list(self, l):
        """Append the items of ``l`` joined with commas, as one string."""
        return self.add_string(','.join(l))

    def add(self, i):
        """
        Append ``i`` using the encoding that matches its type.

        Strings go through add_string() and lists through add_list().
        Integers go through add_int() when they fit in 32 bits and through
        add_mpint() when they are larger.

        Raises Exception('Unknown type') for any other argument.
        """
        if isinstance(i, str):
            return self.add_string(i)
        # bool is a subclass of int; it has always been rejected here, so
        # keep it out of the integer branch rather than packing it silently.
        if isinstance(i, self._INTEGER_TYPES) and not isinstance(i, bool):
            # The size check applies to every integer type: a plain int can
            # also be wider than 32 bits on some platforms.
            if i > self._MAX_UINT32:
                return self.add_mpint(i)
            return self.add_int(i)
        if isinstance(i, list):
            return self.add_list(i)
        raise Exception('Unknown type')