diff options
-rw-r--r-- | dhcpv4/option_domain_search.go | 18 | ||||
-rw-r--r-- | dhcpv4/option_domain_search_test.go | 17 | ||||
-rw-r--r-- | dhcpv6/option_domainsearchlist.go | 16 | ||||
-rw-r--r-- | dhcpv6/option_domainsearchlist_test.go | 16 | ||||
-rw-r--r-- | netboot/netconf.go | 6 | ||||
-rw-r--r-- | rfc1035label/label.go | 83 | ||||
-rw-r--r-- | rfc1035label/label_test.go | 59 |
7 files changed, 144 insertions, 71 deletions
diff --git a/dhcpv4/option_domain_search.go b/dhcpv4/option_domain_search.go index daade0a..c640c0f 100644 --- a/dhcpv4/option_domain_search.go +++ b/dhcpv4/option_domain_search.go @@ -11,7 +11,7 @@ import ( // OptDomainSearch represents an option encapsulating a domain search list. type OptDomainSearch struct { - DomainSearch []string + DomainSearch *rfc1035label.Labels } // Code returns the option code. @@ -22,23 +22,19 @@ func (op *OptDomainSearch) Code() OptionCode { // ToBytes returns a serialized stream of bytes for this option. func (op *OptDomainSearch) ToBytes() []byte { buf := []byte{byte(op.Code()), byte(op.Length())} - buf = append(buf, rfc1035label.LabelsToBytes(op.DomainSearch)...) + buf = append(buf, op.DomainSearch.ToBytes()...) return buf } // Length returns the length of the data portion (excluding option code an byte -// length). +// length). func (op *OptDomainSearch) Length() int { - var length int - for _, label := range op.DomainSearch { - length += len(label) + 2 // add the first and the last length bytes - } - return length + return op.DomainSearch.Length() } // String returns a human-readable string. func (op *OptDomainSearch) String() string { - return fmt.Sprintf("DNS Domain Search List -> %v", op.DomainSearch) + return fmt.Sprintf("DNS Domain Search List -> %v", op.DomainSearch.Labels) } // ParseOptDomainSearch returns a new OptDomainSearch from a byte stream, or @@ -55,9 +51,9 @@ func ParseOptDomainSearch(data []byte) (*OptDomainSearch, error) { if len(data) < 2+length { return nil, ErrShortByteStream } - domainSearch, err := rfc1035label.LabelsFromBytes(data[2:length+2]) + labels, err := rfc1035label.FromBytes(data[2 : length+2]) if err != nil { return nil, err } - return &OptDomainSearch{DomainSearch: domainSearch}, nil + return &OptDomainSearch{DomainSearch: labels}, nil } diff --git a/dhcpv4/option_domain_search_test.go b/dhcpv4/option_domain_search_test.go index 4848a83..590ccd0 100644 --- a/dhcpv4/option_domain_search_test.go +++ b/dhcpv4/option_domain_search_test.go @@ -3,6 +3,7 @@ package dhcpv4 import ( "testing" + "github.com/insomniacslk/dhcp/rfc1035label" "github.com/stretchr/testify/require" ) @@ -15,9 +16,11 @@ func TestParseOptDomainSearch(t *testing.T) { } opt, err := ParseOptDomainSearch(data) require.NoError(t, err) - require.Equal(t, len(opt.DomainSearch), 2) - require.Equal(t, opt.DomainSearch[0], "example.com") - require.Equal(t, opt.DomainSearch[1], "subnet.example.org") + require.Equal(t, 2, len(opt.DomainSearch.Labels)) + require.Equal(t, data[2:], opt.DomainSearch.ToBytes()) + require.Equal(t, len(data[2:]), opt.DomainSearch.Length()) + require.Equal(t, opt.DomainSearch.Labels[0], "example.com") + require.Equal(t, opt.DomainSearch.Labels[1], "subnet.example.org") } func TestOptDomainSearchToBytes(t *testing.T) { @@ -28,9 +31,11 @@ func TestOptDomainSearchToBytes(t *testing.T) { 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, } opt := OptDomainSearch{ - DomainSearch: []string{ - "example.com", - "subnet.example.org", + DomainSearch: &rfc1035label.Labels{ + Labels: []string{ + "example.com", + "subnet.example.org", + }, }, } require.Equal(t, opt.ToBytes(), expected) diff --git a/dhcpv6/option_domainsearchlist.go b/dhcpv6/option_domainsearchlist.go index a3c6f28..b7a356e 100644 --- a/dhcpv6/option_domainsearchlist.go +++ b/dhcpv6/option_domainsearchlist.go @@ -12,7 +12,7 @@ import ( // OptDomainSearchList list implements a OptionDomainSearchList option type OptDomainSearchList struct { - DomainSearchList []string + DomainSearchList *rfc1035label.Labels } func (op *OptDomainSearchList) Code() OptionCode { @@ -23,30 +23,30 @@ func (op *OptDomainSearchList) ToBytes() []byte { buf := make([]byte, 4) binary.BigEndian.PutUint16(buf[0:2], uint16(OptionDomainSearchList)) binary.BigEndian.PutUint16(buf[2:4], uint16(op.Length())) - buf = append(buf, rfc1035label.LabelsToBytes(op.DomainSearchList)...) + buf = append(buf, op.DomainSearchList.ToBytes()...) return buf } func (op *OptDomainSearchList) Length() int { var length int - for _, label := range op.DomainSearchList { + for _, label := range op.DomainSearchList.Labels { length += len(label) + 2 // add the first and the last length bytes } return length } func (op *OptDomainSearchList) String() string { - return fmt.Sprintf("OptDomainSearchList{searchlist=%v}", op.DomainSearchList) + return fmt.Sprintf("OptDomainSearchList{searchlist=%v}", op.DomainSearchList.Labels) } -// build an OptDomainSearchList structure from a sequence of bytes. -// The input data does not include option code and length bytes. +// ParseOptDomainSearchList builds an OptDomainSearchList structure from a sequence +// of bytes. The input data does not include option code and length bytes. func ParseOptDomainSearchList(data []byte) (*OptDomainSearchList, error) { opt := OptDomainSearchList{} - var err error - opt.DomainSearchList, err = rfc1035label.LabelsFromBytes(data) + labels, err := rfc1035label.FromBytes(data) if err != nil { return nil, err } + opt.DomainSearchList = labels return &opt, nil } diff --git a/dhcpv6/option_domainsearchlist_test.go b/dhcpv6/option_domainsearchlist_test.go index 972a5bc..0b4b6b0 100644 --- a/dhcpv6/option_domainsearchlist_test.go +++ b/dhcpv6/option_domainsearchlist_test.go @@ -3,6 +3,7 @@ package dhcpv6 import ( "testing" + "github.com/insomniacslk/dhcp/rfc1035label" "github.com/stretchr/testify/require" ) @@ -14,9 +15,10 @@ func TestParseOptDomainSearchList(t *testing.T) { opt, err := ParseOptDomainSearchList(data) require.NoError(t, err) require.Equal(t, OptionDomainSearchList, opt.Code()) - require.Equal(t, 2, len(opt.DomainSearchList)) - require.Equal(t, "example.com", opt.DomainSearchList[0]) - require.Equal(t, "subnet.example.org", opt.DomainSearchList[1]) + require.Equal(t, 2, len(opt.DomainSearchList.Labels)) + require.Equal(t, len(data), opt.DomainSearchList.Length()) + require.Equal(t, "example.com", opt.DomainSearchList.Labels[0]) + require.Equal(t, "subnet.example.org", opt.DomainSearchList.Labels[1]) require.Contains(t, opt.String(), "searchlist=[example.com subnet.example.org]", "String() should contain the correct domain search output") } @@ -28,9 +30,11 @@ func TestOptDomainSearchListToBytes(t *testing.T) { 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, } opt := OptDomainSearchList{ - DomainSearchList: []string{ - "example.com", - "subnet.example.org", + DomainSearchList: &rfc1035label.Labels{ + Labels: []string{ + "example.com", + "subnet.example.org", + }, }, } require.Equal(t, expected, opt.ToBytes()) diff --git a/netboot/netconf.go b/netboot/netconf.go index ec6a5e1..aef4efa 100644 --- a/netboot/netconf.go +++ b/netboot/netconf.go @@ -69,7 +69,7 @@ func GetNetConfFromPacketv6(d *dhcpv6.DHCPv6Message) (*NetConf, error) { if opt != nil { odomains := opt.(*dhcpv6.OptDomainSearchList) // TODO should this be copied? - netconf.DNSSearchList = odomains.DomainSearchList + netconf.DNSSearchList = odomains.DomainSearchList.Labels } return &netconf, nil @@ -134,10 +134,10 @@ func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) { dnsDomainSearchListOption := d.GetOneOption(dhcpv4.OptionDNSDomainSearchList) if dnsDomainSearchListOption != nil { dnsSearchList := dnsDomainSearchListOption.(*dhcpv4.OptDomainSearch).DomainSearch - if len(dnsSearchList) == 0 { + if len(dnsSearchList.Labels) == 0 { return nil, errors.New("dns search list is empty") } - netconf.DNSSearchList = dnsSearchList + netconf.DNSSearchList = dnsSearchList.Labels } // get default gateway diff --git a/rfc1035label/label.go b/rfc1035label/label.go index 26d8a49..5093de8 100644 --- a/rfc1035label/label.go +++ b/rfc1035label/label.go @@ -8,11 +8,78 @@ import ( // This implements RFC 1035 labels, including compression. // https://tools.ietf.org/html/rfc1035#section-4.1.4 -// LabelsFromBytes decodes a serialized stream and returns a list of labels -func LabelsFromBytes(buf []byte) ([]string, error) { +// Labels represents RFC1035 labels +type Labels struct { + // original contains the original bytes if the object was parsed from a byte + // sequence, or nil otherwise. The `original` field is necessary to deal + // with compressed labels. If the labels are further modified, the original + // content is invalidated and no compression will be used. + original []byte + // Labels contains the parsed labels. A change here invalidates the + // `original` object. + Labels []string +} + +// same compares two string arrays +func same(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +// ToBytes returns a byte sequence representing the labels. If the original +// sequence is modified, the labels are parsed again, otherwise the original +// byte sequence is returned. +func (l *Labels) ToBytes() []byte { + // if the original byte sequence has been modified, invalidate it and + // serialize again. + // NOTE: this function is not thread-safe. If multiple threads modify + // the `Labels` field, the result may be wrong. + originalLabels, err := labelsFromBytes(l.original) + // if the original object has not been modified, or we cannot parse it, + // return the original bytes. + if err != nil || (l.original != nil && same(originalLabels, l.Labels)) { + return l.original + } + return labelsToBytes(l.Labels) +} + +// Length returns the length in bytes of the serialized labels +func (l *Labels) Length() int { + return len(l.ToBytes()) +} + +// NewLabels returns an initialized Labels object. +func NewLabels() *Labels { + return &Labels{ + Labels: make([]string, 0), + } +} + +// FromBytes returns a Labels object from the given byte sequence, or an error if +// any. +func FromBytes(data []byte) (*Labels, error) { + lab := NewLabels() + l, err := labelsFromBytes(data) + if err != nil { + return nil, err + } + lab.original = data + lab.Labels = l + return lab, nil +} + +// fromBytes decodes a serialized stream and returns a list of labels +func labelsFromBytes(buf []byte) ([]string, error) { var ( - pos, oldPos int labels = make([]string, 0) + pos, oldPos int label string handlingPointer bool ) @@ -58,8 +125,8 @@ func LabelsFromBytes(buf []byte) ([]string, error) { return labels, nil } -// LabelToBytes encodes a label and returns a serialized stream of bytes -func LabelToBytes(label string) []byte { +// labelToBytes encodes a label and returns a serialized stream of bytes +func labelToBytes(label string) []byte { var encodedLabel []byte if len(label) == 0 { return []byte{0} @@ -71,12 +138,12 @@ func LabelToBytes(label string) []byte { return append(encodedLabel, 0) } -// LabelsToBytes encodes a list of labels and returns a serialized stream of +// labelsToBytes encodes a list of labels and returns a serialized stream of // bytes -func LabelsToBytes(labels []string) []byte { +func labelsToBytes(labels []string) []byte { var encodedLabels []byte for _, label := range labels { - encodedLabels = append(encodedLabels, LabelToBytes(label)...) + encodedLabels = append(encodedLabels, labelToBytes(label)...) } return encodedLabels } diff --git a/rfc1035label/label_test.go b/rfc1035label/label_test.go index 3a69f2b..6098e44 100644 --- a/rfc1035label/label_test.go +++ b/rfc1035label/label_test.go @@ -7,42 +7,35 @@ import ( ) func TestLabelsFromBytes(t *testing.T) { - labels, err := LabelsFromBytes([]byte{ + expected := []byte{ 0x9, 's', 'l', 'a', 'c', 'k', 'w', 'a', 'r', 'e', 0x2, 'i', 't', 0x0, - }) + } + labels, err := FromBytes(expected) require.NoError(t, err) - require.Equal(t, 1, len(labels)) - require.Equal(t, "slackware.it", labels[0]) + require.Equal(t, 1, len(labels.Labels)) + require.Equal(t, len(expected), labels.Length()) + require.Equal(t, expected, labels.ToBytes()) + require.Equal(t, "slackware.it", labels.Labels[0]) } func TestLabelsFromBytesZeroLength(t *testing.T) { - labels, err := LabelsFromBytes([]byte{}) + labels, err := FromBytes([]byte{}) require.NoError(t, err) - require.Equal(t, 0, len(labels)) + require.Equal(t, 0, len(labels.Labels)) + require.Equal(t, 0, labels.Length()) + require.Equal(t, []byte{}, labels.ToBytes()) } func TestLabelsFromBytesInvalidLength(t *testing.T) { - labels, err := LabelsFromBytes([]byte{0x5, 0xaa, 0xbb}) // short length + _, err := FromBytes([]byte{0x5, 0xaa, 0xbb}) // short length require.Error(t, err) - require.Equal(t, 0, len(labels)) } func TestLabelsFromBytesInvalidLengthOffByOne(t *testing.T) { - labels, err := LabelsFromBytes([]byte{0x3, 0xaa, 0xbb}) // short length + _, err := FromBytes([]byte{0x3, 0xaa, 0xbb}) // short length require.Error(t, err) - require.Equal(t, 0, len(labels)) -} - -func TestLabelToBytes(t *testing.T) { - encodedLabel := LabelToBytes("slackware.it") - expected := []byte{ - 0x9, 's', 'l', 'a', 'c', 'k', 'w', 'a', 'r', 'e', - 0x2, 'i', 't', - 0x0, - } - require.Equal(t, expected, encodedLabel) } func TestLabelsToBytes(t *testing.T) { @@ -55,14 +48,20 @@ func TestLabelsToBytes(t *testing.T) { 2, 'i', 't', 0, } - encodedLabels := LabelsToBytes([]string{"slackware.it", "insomniac.slackware.it"}) - require.Equal(t, expected, encodedLabels) + labels := Labels{ + Labels: []string{ + "slackware.it", + "insomniac.slackware.it", + }, + } + require.Equal(t, expected, labels.ToBytes()) } func TestLabelToBytesZeroLength(t *testing.T) { - encodedLabel := LabelToBytes("") - expected := []byte{0} - require.Equal(t, expected, encodedLabel) + labels := Labels{ + Labels: []string{""}, + } + require.Equal(t, []byte{0}, labels.ToBytes()) } func TestCompressedLabel(t *testing.T) { @@ -89,9 +88,11 @@ func TestCompressedLabel(t *testing.T) { "systemboot.org", } - labels, err := LabelsFromBytes(data) + labels, err := FromBytes(data) require.NoError(t, err) - require.Equal(t, expected, labels) + require.Equal(t, 4, len(labels.Labels)) + require.Equal(t, expected, labels.Labels) + require.Equal(t, len(data), labels.Length()) } func TestShortCompressedLabel(t *testing.T) { @@ -105,7 +106,7 @@ func TestShortCompressedLabel(t *testing.T) { 192, } - _, err := LabelsFromBytes(data) + _, err := FromBytes(data) require.Error(t, err) } @@ -121,6 +122,6 @@ func TestNestedCompressedLabel(t *testing.T) { 9, 'i', 'n', 's', 'o', 'm', 'n', 'i', 'a', 'c', 192, 5, } - _, err := LabelsFromBytes(data) + _, err := FromBytes(data) require.Error(t, err) } |