diff options
-rw-r--r-- | dhcpv4/option_domain_search.go | 55 | ||||
-rw-r--r-- | dhcpv4/option_domain_search_test.go | 37 | ||||
-rw-r--r-- | dhcpv4/option_rfc1035label.go | 53 | ||||
-rw-r--r-- | dhcpv4/option_rfc1035label_test.go | 66 |
4 files changed, 211 insertions, 0 deletions
diff --git a/dhcpv4/option_domain_search.go b/dhcpv4/option_domain_search.go new file mode 100644 index 0000000..c30cbe1 --- /dev/null +++ b/dhcpv4/option_domain_search.go @@ -0,0 +1,55 @@ +package dhcpv4 + +// This module defines the OptDomainSearch structure. +// https://tools.ietf.org/html/rfc3397 + +import ( + "fmt" +) + +// OptDomainSearch represents an option encapsulating a domain search list. +type OptDomainSearch struct { + DomainSearch []string +} + +func (op *OptDomainSearch) Code() OptionCode { + return OptionDNSDomainSearchList +} + +func (op *OptDomainSearch) ToBytes() []byte { + buf := []byte{byte(op.Code()), byte(op.Length())} + buf = append(buf, LabelsToBytes(op.DomainSearch)...) + return buf +} + +func (op *OptDomainSearch) Length() int { + var length int + for _, label := range op.DomainSearch { + length += len(label) + 2 // add the first and the last length bytes + } + return length +} + +func (op *OptDomainSearch) String() string { + return fmt.Sprintf("DNS Domain Search List ->", op.DomainSearch) +} + +// build an OptDomainSearch structure from a sequence of bytes. +func ParseOptDomainSearch(data []byte) (*OptDomainSearch, error) { + if len(data) < 2 { + return nil, ErrShortByteStream + } + code := OptionCode(data[0]) + if code != OptionDNSDomainSearchList { + return nil, fmt.Errorf("expected code %v, got %v", OptionDNSDomainSearchList, code) + } + length := int(data[1]) + if len(data) < 2+length { + return nil, ErrShortByteStream + } + domainSearch, err := LabelsFromBytes(data[2:length+2]) + if err != nil { + return nil, err + } + return &OptDomainSearch{DomainSearch: domainSearch}, nil +} diff --git a/dhcpv4/option_domain_search_test.go b/dhcpv4/option_domain_search_test.go new file mode 100644 index 0000000..4848a83 --- /dev/null +++ b/dhcpv4/option_domain_search_test.go @@ -0,0 +1,37 @@ +package dhcpv4 + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseOptDomainSearch(t *testing.T) { + data := []byte{ + 119, // OptionDNSDomainSearchList + 33, // length + 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, + 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, + } + opt, err := ParseOptDomainSearch(data) + require.NoError(t, err) + require.Equal(t, len(opt.DomainSearch), 2) + require.Equal(t, opt.DomainSearch[0], "example.com") + require.Equal(t, opt.DomainSearch[1], "subnet.example.org") +} + +func TestOptDomainSearchToBytes(t *testing.T) { + expected := []byte{ + 119, // OptionDNSDomainSearchList + 33, // length + 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, + 6, 's', 'u', 'b', 'n', 'e', 't', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'o', 'r', 'g', 0, + } + opt := OptDomainSearch{ + DomainSearch: []string{ + "example.com", + "subnet.example.org", + }, + } + require.Equal(t, opt.ToBytes(), expected) +} diff --git a/dhcpv4/option_rfc1035label.go b/dhcpv4/option_rfc1035label.go new file mode 100644 index 0000000..b78d8da --- /dev/null +++ b/dhcpv4/option_rfc1035label.go @@ -0,0 +1,53 @@ +package dhcpv4 + +import ( + "fmt" + "strings" +) + +func LabelsFromBytes(buf []byte) ([]string, error) { + var ( + pos = 0 + domains = make([]string, 0) + label = "" + ) + for { + if pos >= len(buf) { + return domains, nil + } + length := int(buf[pos]) + pos++ + if length == 0 { + domains = append(domains, label) + label = "" + } + if len(buf)-pos < length { + return nil, fmt.Errorf("DomainNamesFromBytes: invalid short label length") + } + if label != "" { + label += "." + } + label += string(buf[pos : pos+length]) + pos += length + } +} + +func LabelToBytes(label string) []byte { + var encodedLabel []byte + if len(label) == 0 { + return []byte{0} + } + for _, part := range strings.Split(label, ".") { + encodedLabel = append(encodedLabel, byte(len(part))) + encodedLabel = append(encodedLabel, []byte(part)...) + } + return append(encodedLabel, 0) +} + +func LabelsToBytes(labels []string) []byte { + var encodedLabels []byte + for _, label := range labels { + encodedLabels = append(encodedLabels, LabelToBytes(label)...) + } + return encodedLabels +} diff --git a/dhcpv4/option_rfc1035label_test.go b/dhcpv4/option_rfc1035label_test.go new file mode 100644 index 0000000..30c87c8 --- /dev/null +++ b/dhcpv4/option_rfc1035label_test.go @@ -0,0 +1,66 @@ +package dhcpv4 + +import ( + "bytes" + "testing" +) + +func TestLabelsFromBytes(t *testing.T) { + labels, err := LabelsFromBytes([]byte{ + 0x9, 's', 'l', 'a', 'c', 'k', 'w', 'a', 'r', 'e', + 0x2, 'i', 't', + 0x0, + }) + if err != nil { + t.Fatal(err) + } + if len(labels) != 1 { + t.Fatalf("Invalid labels length. Expected: 1, got: %v", len(labels)) + } + if labels[0] != "slackware.it" { + t.Fatalf("Invalid label. Expected: %v, got: %v'", "slackware.it", labels[0]) + } +} + +func TestLabelsFromBytesZeroLength(t *testing.T) { + labels, err := LabelsFromBytes([]byte{}) + if err != nil { + t.Fatal(err) + } + if len(labels) != 0 { + t.Fatalf("Invalid labels length. Expected: 0, got: %v", len(labels)) + } +} + +func TestLabelsFromBytesInvalidLength(t *testing.T) { + labels, err := LabelsFromBytes([]byte{0x3, 0xaa, 0xbb}) // short length + if err == nil { + t.Fatal("Expected error, got nil") + } + if len(labels) != 0 { + t.Fatalf("Invalid labels length. Expected: 0, got: %v", len(labels)) + } + if labels != nil { + t.Fatalf("Invalid label. Expected nil, got %v", labels) + } +} + +func TestLabelToBytes(t *testing.T) { + encodedLabel := LabelToBytes("slackware.it") + expected := []byte{ + 0x9, 's', 'l', 'a', 'c', 'k', 'w', 'a', 'r', 'e', + 0x2, 'i', 't', + 0x0, + } + if !bytes.Equal(encodedLabel, expected) { + t.Fatalf("Invalid label. Expected: %v, got: %v", expected, encodedLabel) + } +} + +func TestLabelToBytesZeroLength(t *testing.T) { + encodedLabel := LabelToBytes("") + expected := []byte{0} + if !bytes.Equal(encodedLabel, expected) { + t.Fatalf("Invalid label. Expected: %v, got: %v", expected, encodedLabel) + } +} |