summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv4/option_domain_search.go18
-rw-r--r--dhcpv4/option_domain_search_test.go17
-rw-r--r--dhcpv6/option_domainsearchlist.go16
-rw-r--r--dhcpv6/option_domainsearchlist_test.go16
-rw-r--r--netboot/netconf.go6
-rw-r--r--rfc1035label/label.go83
-rw-r--r--rfc1035label/label_test.go59
7 files changed, 144 insertions, 71 deletions
diff --git a/dhcpv4/option_domain_search.go b/dhcpv4/option_domain_search.go
index daade0a..c640c0f 100644
--- a/dhcpv4/option_domain_search.go
+++ b/dhcpv4/option_domain_search.go
@@ -11,7 +11,7 @@ import (
// OptDomainSearch represents an option encapsulating a domain search list.
type OptDomainSearch struct {
- DomainSearch []string
+ DomainSearch *rfc1035label.Labels
}
// Code returns the option code.
@@ -22,23 +22,19 @@ func (op *OptDomainSearch) Code() OptionCode {
// ToBytes returns a serialized stream of bytes for this option.
func (op *OptDomainSearch) ToBytes() []byte {
buf := []byte{byte(op.Code()), byte(op.Length())}
- buf = append(buf, rfc1035label.LabelsToBytes(op.DomainSearch)...)
+ buf = append(buf, op.DomainSearch.ToBytes()...)
return buf
}
// Length returns the length of the data portion (excluding option code an byte
-// length).
+// length).
func (op *OptDomainSearch) Length() int {
- var length int
- for _, label := range op.DomainSearch {
- length += len(label) + 2 // add the first and the last length bytes
- }
- return length
+ return op.DomainSearch.Length()
}
// String returns a human-readable string.
func (op *OptDomainSearch) String() string {
- return fmt.Sprintf("DNS Domain Search List -> %v", op.DomainSearch)
+ return fmt.Sprintf("DNS Domain Search List -> %v", op.DomainSearch.Labels)
}
// ParseOptDomainSearch returns a new OptDomainSearch from a byte stream, or
@@ -55,9 +51,9 @@ func ParseOptDomainSearch(data []byte) (*OptDomainSearch, error) {
if len(data) < 2+length {
return nil, ErrShortByteStream
}
- domainSearch, err := rfc1035label.LabelsFromBytes(data[2:length+2])
+ labels, err := rfc1035label.FromBytes(data[2 : length+2])
if err != nil {
return nil, err
}
- return &OptDomainSearch{DomainSearch: domainSearch}, nil
+ return &OptDomainSearch{DomainSearch: labels}, nil
}
diff --git a/dhcpv4/option_domain_search_test.go b/dhcpv4/option_domain_search_test.go
index 4848a83..590ccd0 100644
--- a/dhcpv4/option_domain_search_test.go
+++ b/dhcpv4/option_domain_search_test.go
@@ -3,6 +3,7 @@ package dhcpv4
import (
"testing"
+ "github.com/insomniacslk/dhcp/rfc1035label"
"github.com/stretchr/testify/require"
)
@@ -15,9 +16,11 @@ func TestParseOptDomainSearch(t *testing.T) {
}
opt, err := ParseOptDomainSearch(data)
require.NoError(t, err)
- require.Equal(t, len(opt.DomainSearch), 2)
- require.Equal(t, opt.DomainSearch[0], "example.com")
- require.Equal(t, opt.DomainSearch[1], "subnet.example.org")
+ require.Equal(t, 2, len(opt.DomainSearch.Labels))
+ require.Equal(t, data[2:], opt.DomainSearch.ToBytes())
+ require.Equal(t, len(data[2:]), opt.DomainSearch.Length())
+ require.Equal(t, opt.DomainSearch.Labels[0], "example.com")
+ require.Equal(t, opt.DomainSearch.Labels[1], "subnet.example.org")
}
func TestOptDomainSearchToBytes(t *testing.T) {
@@ -28,9 +31,11 @@ func TestOptDomainSearchToBytes(t *testing.T) {
6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0,
}
opt := OptDomainSearch{
- DomainSearch: []string{
- "example.com",
- "subnet.example.org",
+ DomainSearch: &rfc1035label.Labels{
+ Labels: []string{
+ "example.com",
+ "subnet.example.org",
+ },
},
}
require.Equal(t, opt.ToBytes(), expected)
diff --git a/dhcpv6/option_domainsearchlist.go b/dhcpv6/option_domainsearchlist.go
index a3c6f28..b7a356e 100644
--- a/dhcpv6/option_domainsearchlist.go
+++ b/dhcpv6/option_domainsearchlist.go
@@ -12,7 +12,7 @@ import (
// OptDomainSearchList list implements a OptionDomainSearchList option
type OptDomainSearchList struct {
- DomainSearchList []string
+ DomainSearchList *rfc1035label.Labels
}
func (op *OptDomainSearchList) Code() OptionCode {
@@ -23,30 +23,30 @@ func (op *OptDomainSearchList) ToBytes() []byte {
buf := make([]byte, 4)
binary.BigEndian.PutUint16(buf[0:2], uint16(OptionDomainSearchList))
binary.BigEndian.PutUint16(buf[2:4], uint16(op.Length()))
- buf = append(buf, rfc1035label.LabelsToBytes(op.DomainSearchList)...)
+ buf = append(buf, op.DomainSearchList.ToBytes()...)
return buf
}
func (op *OptDomainSearchList) Length() int {
var length int
- for _, label := range op.DomainSearchList {
+ for _, label := range op.DomainSearchList.Labels {
length += len(label) + 2 // add the first and the last length bytes
}
return length
}
func (op *OptDomainSearchList) String() string {
- return fmt.Sprintf("OptDomainSearchList{searchlist=%v}", op.DomainSearchList)
+ return fmt.Sprintf("OptDomainSearchList{searchlist=%v}", op.DomainSearchList.Labels)
}
-// build an OptDomainSearchList structure from a sequence of bytes.
-// The input data does not include option code and length bytes.
+// ParseOptDomainSearchList builds an OptDomainSearchList structure from a sequence
+// of bytes. The input data does not include option code and length bytes.
func ParseOptDomainSearchList(data []byte) (*OptDomainSearchList, error) {
opt := OptDomainSearchList{}
- var err error
- opt.DomainSearchList, err = rfc1035label.LabelsFromBytes(data)
+ labels, err := rfc1035label.FromBytes(data)
if err != nil {
return nil, err
}
+ opt.DomainSearchList = labels
return &opt, nil
}
diff --git a/dhcpv6/option_domainsearchlist_test.go b/dhcpv6/option_domainsearchlist_test.go
index 972a5bc..0b4b6b0 100644
--- a/dhcpv6/option_domainsearchlist_test.go
+++ b/dhcpv6/option_domainsearchlist_test.go
@@ -3,6 +3,7 @@ package dhcpv6
import (
"testing"
+ "github.com/insomniacslk/dhcp/rfc1035label"
"github.com/stretchr/testify/require"
)
@@ -14,9 +15,10 @@ func TestParseOptDomainSearchList(t *testing.T) {
opt, err := ParseOptDomainSearchList(data)
require.NoError(t, err)
require.Equal(t, OptionDomainSearchList, opt.Code())
- require.Equal(t, 2, len(opt.DomainSearchList))
- require.Equal(t, "example.com", opt.DomainSearchList[0])
- require.Equal(t, "subnet.example.org", opt.DomainSearchList[1])
+ require.Equal(t, 2, len(opt.DomainSearchList.Labels))
+ require.Equal(t, len(data), opt.DomainSearchList.Length())
+ require.Equal(t, "example.com", opt.DomainSearchList.Labels[0])
+ require.Equal(t, "subnet.example.org", opt.DomainSearchList.Labels[1])
require.Contains(t, opt.String(), "searchlist=[example.com subnet.example.org]", "String() should contain the correct domain search output")
}
@@ -28,9 +30,11 @@ func TestOptDomainSearchListToBytes(t *testing.T) {
6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0,
}
opt := OptDomainSearchList{
- DomainSearchList: []string{
- "example.com",
- "subnet.example.org",
+ DomainSearchList: &rfc1035label.Labels{
+ Labels: []string{
+ "example.com",
+ "subnet.example.org",
+ },
},
}
require.Equal(t, expected, opt.ToBytes())
diff --git a/netboot/netconf.go b/netboot/netconf.go
index ec6a5e1..aef4efa 100644
--- a/netboot/netconf.go
+++ b/netboot/netconf.go
@@ -69,7 +69,7 @@ func GetNetConfFromPacketv6(d *dhcpv6.DHCPv6Message) (*NetConf, error) {
if opt != nil {
odomains := opt.(*dhcpv6.OptDomainSearchList)
// TODO should this be copied?
- netconf.DNSSearchList = odomains.DomainSearchList
+ netconf.DNSSearchList = odomains.DomainSearchList.Labels
}
return &netconf, nil
@@ -134,10 +134,10 @@ func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) {
dnsDomainSearchListOption := d.GetOneOption(dhcpv4.OptionDNSDomainSearchList)
if dnsDomainSearchListOption != nil {
dnsSearchList := dnsDomainSearchListOption.(*dhcpv4.OptDomainSearch).DomainSearch
- if len(dnsSearchList) == 0 {
+ if len(dnsSearchList.Labels) == 0 {
return nil, errors.New("dns search list is empty")
}
- netconf.DNSSearchList = dnsSearchList
+ netconf.DNSSearchList = dnsSearchList.Labels
}
// get default gateway
diff --git a/rfc1035label/label.go b/rfc1035label/label.go
index 26d8a49..5093de8 100644
--- a/rfc1035label/label.go
+++ b/rfc1035label/label.go
@@ -8,11 +8,78 @@ import (
// This implements RFC 1035 labels, including compression.
// https://tools.ietf.org/html/rfc1035#section-4.1.4
-// LabelsFromBytes decodes a serialized stream and returns a list of labels
-func LabelsFromBytes(buf []byte) ([]string, error) {
+// Labels represents RFC1035 labels
+type Labels struct {
+ // original contains the original bytes if the object was parsed from a byte
+ // sequence, or nil otherwise. The `original` field is necessary to deal
+ // with compressed labels. If the labels are further modified, the original
+ // content is invalidated and no compression will be used.
+ original []byte
+ // Labels contains the parsed labels. A change here invalidates the
+ // `original` object.
+ Labels []string
+}
+
+// same compares two string arrays
+func same(a, b []string) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := 0; i < len(a); i++ {
+ if a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// ToBytes returns a byte sequence representing the labels. If the original
+// sequence is modified, the labels are parsed again, otherwise the original
+// byte sequence is returned.
+func (l *Labels) ToBytes() []byte {
+ // if the original byte sequence has been modified, invalidate it and
+ // serialize again.
+ // NOTE: this function is not thread-safe. If multiple threads modify
+ // the `Labels` field, the result may be wrong.
+ originalLabels, err := labelsFromBytes(l.original)
+ // if the original object has not been modified, or we cannot parse it,
+ // return the original bytes.
+ if err != nil || (l.original != nil && same(originalLabels, l.Labels)) {
+ return l.original
+ }
+ return labelsToBytes(l.Labels)
+}
+
+// Length returns the length in bytes of the serialized labels
+func (l *Labels) Length() int {
+ return len(l.ToBytes())
+}
+
+// NewLabels returns an initialized Labels object.
+func NewLabels() *Labels {
+ return &Labels{
+ Labels: make([]string, 0),
+ }
+}
+
+// FromBytes returns a Labels object from the given byte sequence, or an error if
+// any.
+func FromBytes(data []byte) (*Labels, error) {
+ lab := NewLabels()
+ l, err := labelsFromBytes(data)
+ if err != nil {
+ return nil, err
+ }
+ lab.original = data
+ lab.Labels = l
+ return lab, nil
+}
+
+// fromBytes decodes a serialized stream and returns a list of labels
+func labelsFromBytes(buf []byte) ([]string, error) {
var (
- pos, oldPos int
labels = make([]string, 0)
+ pos, oldPos int
label string
handlingPointer bool
)
@@ -58,8 +125,8 @@ func LabelsFromBytes(buf []byte) ([]string, error) {
return labels, nil
}
-// LabelToBytes encodes a label and returns a serialized stream of bytes
-func LabelToBytes(label string) []byte {
+// labelToBytes encodes a label and returns a serialized stream of bytes
+func labelToBytes(label string) []byte {
var encodedLabel []byte
if len(label) == 0 {
return []byte{0}
@@ -71,12 +138,12 @@ func LabelToBytes(label string) []byte {
return append(encodedLabel, 0)
}
-// LabelsToBytes encodes a list of labels and returns a serialized stream of
+// labelsToBytes encodes a list of labels and returns a serialized stream of
// bytes
-func LabelsToBytes(labels []string) []byte {
+func labelsToBytes(labels []string) []byte {
var encodedLabels []byte
for _, label := range labels {
- encodedLabels = append(encodedLabels, LabelToBytes(label)...)
+ encodedLabels = append(encodedLabels, labelToBytes(label)...)
}
return encodedLabels
}
diff --git a/rfc1035label/label_test.go b/rfc1035label/label_test.go
index 3a69f2b..6098e44 100644
--- a/rfc1035label/label_test.go
+++ b/rfc1035label/label_test.go
@@ -7,42 +7,35 @@ import (
)
func TestLabelsFromBytes(t *testing.T) {
- labels, err := LabelsFromBytes([]byte{
+ expected := []byte{
0x9, 's', 'l', 'a', 'c', 'k', 'w', 'a', 'r', 'e',
0x2, 'i', 't',
0x0,
- })
+ }
+ labels, err := FromBytes(expected)
require.NoError(t, err)
- require.Equal(t, 1, len(labels))
- require.Equal(t, "slackware.it", labels[0])
+ require.Equal(t, 1, len(labels.Labels))
+ require.Equal(t, len(expected), labels.Length())
+ require.Equal(t, expected, labels.ToBytes())
+ require.Equal(t, "slackware.it", labels.Labels[0])
}
func TestLabelsFromBytesZeroLength(t *testing.T) {
- labels, err := LabelsFromBytes([]byte{})
+ labels, err := FromBytes([]byte{})
require.NoError(t, err)
- require.Equal(t, 0, len(labels))
+ require.Equal(t, 0, len(labels.Labels))
+ require.Equal(t, 0, labels.Length())
+ require.Equal(t, []byte{}, labels.ToBytes())
}
func TestLabelsFromBytesInvalidLength(t *testing.T) {
- labels, err := LabelsFromBytes([]byte{0x5, 0xaa, 0xbb}) // short length
+ _, err := FromBytes([]byte{0x5, 0xaa, 0xbb}) // short length
require.Error(t, err)
- require.Equal(t, 0, len(labels))
}
func TestLabelsFromBytesInvalidLengthOffByOne(t *testing.T) {
- labels, err := LabelsFromBytes([]byte{0x3, 0xaa, 0xbb}) // short length
+ _, err := FromBytes([]byte{0x3, 0xaa, 0xbb}) // short length
require.Error(t, err)
- require.Equal(t, 0, len(labels))
-}
-
-func TestLabelToBytes(t *testing.T) {
- encodedLabel := LabelToBytes("slackware.it")
- expected := []byte{
- 0x9, 's', 'l', 'a', 'c', 'k', 'w', 'a', 'r', 'e',
- 0x2, 'i', 't',
- 0x0,
- }
- require.Equal(t, expected, encodedLabel)
}
func TestLabelsToBytes(t *testing.T) {
@@ -55,14 +48,20 @@ func TestLabelsToBytes(t *testing.T) {
2, 'i', 't',
0,
}
- encodedLabels := LabelsToBytes([]string{"slackware.it", "insomniac.slackware.it"})
- require.Equal(t, expected, encodedLabels)
+ labels := Labels{
+ Labels: []string{
+ "slackware.it",
+ "insomniac.slackware.it",
+ },
+ }
+ require.Equal(t, expected, labels.ToBytes())
}
func TestLabelToBytesZeroLength(t *testing.T) {
- encodedLabel := LabelToBytes("")
- expected := []byte{0}
- require.Equal(t, expected, encodedLabel)
+ labels := Labels{
+ Labels: []string{""},
+ }
+ require.Equal(t, []byte{0}, labels.ToBytes())
}
func TestCompressedLabel(t *testing.T) {
@@ -89,9 +88,11 @@ func TestCompressedLabel(t *testing.T) {
"systemboot.org",
}
- labels, err := LabelsFromBytes(data)
+ labels, err := FromBytes(data)
require.NoError(t, err)
- require.Equal(t, expected, labels)
+ require.Equal(t, 4, len(labels.Labels))
+ require.Equal(t, expected, labels.Labels)
+ require.Equal(t, len(data), labels.Length())
}
func TestShortCompressedLabel(t *testing.T) {
@@ -105,7 +106,7 @@ func TestShortCompressedLabel(t *testing.T) {
192,
}
- _, err := LabelsFromBytes(data)
+ _, err := FromBytes(data)
require.Error(t, err)
}
@@ -121,6 +122,6 @@ func TestNestedCompressedLabel(t *testing.T) {
9, 'i', 'n', 's', 'o', 'm', 'n', 'i', 'a', 'c',
192, 5,
}
- _, err := LabelsFromBytes(data)
+ _, err := FromBytes(data)
require.Error(t, err)
}