基于Trie(前缀树)实现子网合并
业务需求
今天小管的公司有一项业务需求,将一组带有子网掩码的ip地址进行合并,例如:
"201.203.8.0/24",
"201.203.9.0/24",
表示32为ip地址中前24位是网络号,后8位是主机号。
将其转换为二进制:
11001001 11001011 00001000 00000000 /24
11001001 11001011 00001001 00000000 /24
其中前24位是网络号,我们的目的就是要让网络号位数尽可能的少,
那么对于第三组
00001000
00001001
由于第24位同时包含了0和1,所以可以合并为0000100
,最后的结果就是11001001 11001011 00001000 00000000 /23
,转换为10进制是201.203.8.0/23
。
当然不是所有情况都是这么简单的,例如:
201.203.7.0/24
201.203.8.0/24
虽然相邻,但观察二者的17-24位:
00000111
00001000
第24位同时包含了0和1,但前23位并不相同,所以是不能合并的。
另外,即使ip不相邻也有可能合并:
"194.28.152.0/23",
"194.28.154.0/23",
可以合并为194.28.152.0/22
代码实现
对于这个需求,我第一反应就是利用Trie这一数据结构,Trie的应用十分广泛,包括但不局限于搜索建议、拼写检查、IP路由转发。
在英文单词匹配中,Trie通过大小为26的数组指向下一位。而在本场景下,由于每一位只可能是0或1,Next数组长度可以退化为2,其实本质上就是一颗二叉树。
在这里可以定义数据结构如下,其中Trie代表树中的每一个节点:
// Trie 前缀树
type Trie struct {
IsEnd bool // 表示是否为网络号最后一位
Level int //表示ip地址中的第几位
Parent *Trie // 指向父节点的指针
Next [2]*Trie // 指向子节点0/1的指针
}
输入和输出均为类似194.28.152.0/23
的一串string数组,但我们需要先将其转换为二进制插入Trie,打印时还要重新转换成十进制,所以先写两个Convert函数:
// input: 201.203.8.0/24
// output: 11001001110010110000100000000000 24
func ConvertToBin(input string) (string, int) {
}
// input: 11001001110010110000100/23
// output: 201.203.8.0/23
func ConvertToHex(input string) string {
}
这两个函数实现很简单,因为主要为了实现算法,这里并没有对格式错误的输入情况进行处理,具体代码贴在最后。
接下来看最重要的Insert方法,由于时间紧迫暂时没想到更好的解决方法,我简单画图模拟了一下合并的流程:
由于我们要合并的是网络号,所以遍历mask
- 当遇到IsEnd为true直接退出,因为如果树中已经有了
201.203.8.0/23
,那么再插入201.203.8.0/24
、201.203.9.0/24
都是无意义的。 - 如果t.Next[bit-'0'] == nil,说明要插入的下一个节点不存在,需要创建节点,将Level置为当前的Level+1,Parent设为当前节点。
- 自下而上合并,假设当前节点t为③号节点,由于①是刚刚创建的,isEnd必定为false(不能在第2步初始化时置为true,因为不能确定它是否end节点),所以这里不能直接用t.Next[0].IsEnd && t.Next[1].IsEnd判断。
- 可以观察到,只要t的任一子节点的IsEnd为true并且t.Level为mask-1(倒数第二个网络号)时就可以合并,合并的同时删除①和②,将③的IsEnd置为true,将指针t指向⑤。这时会发现⑤的两个子节点③和④可以继续合并,循环解决即可,其中gap表示当前t指向节点与网络号最后一位的距离,每次加1。
- 如果gap == 1说明没有进行合并,跳出循环;否则说明进行了合并且不能继续合并,直接结束方法。
func (t *Trie) Insert(input string) {
ip, mask := ConvertToBin(input)
for i := 0; i < mask; i++ {
if t.IsEnd {
return
}
bit := ip[i]
if t.Next[bit-'0'] == nil {
t.Next[bit-'0'] = &Trie{
Level: t.Level + 1,
Parent: t,
}
}
gap := 1
for {
if t.Next[0] != nil && t.Next[1] != nil &&
(t.Next[0].IsEnd || t.Next[1].IsEnd) &&
t.Level == mask-gap {
t.Next[0], t.Next[1] = nil, nil
t.IsEnd = true
t = t.Parent
} else if gap == 1 {
break
} else {
return
}
gap++
}
t = t.Next[bit-'0']
}
t.IsEnd = true
}
最后是Print函数,将插入合并后的结果输出:
func (t *Trie) Print() []string {
var res []string
var ip []byte
var dfs func(t *Trie, idx int)
dfs = func(t *Trie, idx int) {
if t == nil {
return
}
if t.IsEnd {
res = append(res, string(ip)+"/"+strconv.Itoa(t.Level))
return
}
if t.Next[0] != nil {
ip = append(ip, '0')
dfs(t.Next[0], idx+1)
ip = ip[:len(ip)-1]
}
if t.Next[1] != nil {
ip = append(ip, '1')
dfs(t.Next[1], idx+1)
ip = ip[:len(ip)-1]
}
}
dfs(t, 0)
return res
}
这里就是基本的二叉树前序遍历,如果遇到IsEnd就把结果存起来。
测试
写了几组测试用例,以一个稍微复杂的例子为例:
测试结果:
源代码
1.merge.go
package merge
import (
"bytes"
"fmt"
"strconv"
"strings"
)
// Trie 0/1前缀树节点
type Trie struct {
IsEnd bool
Level int
Parent *Trie
Next [2]*Trie
}
// NewTrie 返回一个dummy根节点
func NewTrie() *Trie {
return new(Trie)
}
// Insert 插入并合并
func (t *Trie) Insert(input string) {
ip, mask := ConvertToBin(input)
for i := 0; i < mask; i++ {
if t.IsEnd {
return
}
bit := ip[i]
if t.Next[bit-'0'] == nil {
t.Next[bit-'0'] = &Trie{
Level: t.Level + 1,
Parent: t,
}
}
gap := 1
for {
if t.Next[0] != nil && t.Next[1] != nil &&
(t.Next[0].IsEnd || t.Next[1].IsEnd) &&
t.Level == mask-gap {
t.Next[0], t.Next[1] = nil, nil
t.IsEnd = true
t = t.Parent
} else if gap == 1 {
break
} else {
return
}
gap++
}
t = t.Next[bit-'0']
}
t.IsEnd = true
}
// LevelOrder 层序遍历,便于调试
func (t *Trie) LevelOrder() {
var queue []*Trie
queue = append(queue, t)
for len(queue) > 0 {
cur := queue[0]
queue = queue[1:]
fmt.Println(cur)
if cur.Next[0] != nil {
queue = append(queue, cur.Next[0])
}
if cur.Next[1] != nil {
queue = append(queue, cur.Next[1])
}
}
}
// Print 输出结果
func (t *Trie) Print() []string {
var res []string
var ip []byte
var dfs func(t *Trie, idx int)
dfs = func(t *Trie, idx int) {
if t == nil {
return
}
if t.IsEnd {
res = append(res, string(ip)+"/"+strconv.Itoa(t.Level))
return
}
if t.Next[0] != nil {
ip = append(ip, '0')
dfs(t.Next[0], idx+1)
ip = ip[:len(ip)-1]
}
if t.Next[1] != nil {
ip = append(ip, '1')
dfs(t.Next[1], idx+1)
ip = ip[:len(ip)-1]
}
}
dfs(t, 0)
return res
}
// ConvertToBin 十进制转二进制
func ConvertToBin(input string) (string, int) {
arr := strings.Split(input, "/")
pre, post := arr[0], arr[1]
ip := strings.Split(pre, ".")
var buf bytes.Buffer
for _, v := range ip {
num, _ := strconv.Atoi(v)
buf.WriteString(fmt.Sprintf("%08b", num))
}
mask, _ := strconv.Atoi(post)
return buf.String(), mask
}
// ConvertToHex 二进制转十进制
func ConvertToHex(input string) string {
arr := strings.Split(input, "/")
pre, post := arr[0], arr[1]
add := 32 - len(pre)
for i := 0; i < add; i++ {
pre += "0"
}
var buf bytes.Buffer
var point string
for i := 0; i < 4; i++ {
buf.WriteString(point)
numStr := pre[8*i : 8*(i+1)]
num, _ := strconv.ParseInt(numStr, 2, 64)
buf.WriteString(strconv.Itoa(int(num)))
point = "."
}
buf.WriteString("/" + post)
return buf.String()
}
2.merge_test.go
package merge
import (
"fmt"
"sort"
"testing"
)
func TestMerge(t *testing.T) {
var tests = []struct {
input []string
want []string
}{
{
[]string{"201.203.8.0/25", "201.203.8.0/24", "201.203.9.0/24"},
[]string{"201.203.8.0/23"},
},
{
[]string{"210.203.10.0/24", "210.203.11.0/24", "210.203.12.0/24"},
[]string{"210.203.10.0/23", "210.203.12.0/24"},
},
{
[]string{"194.28.152.0/23", "194.28.154.0/23", "194.28.156.0/22", "194.28.144.0/22"},
[]string{"194.28.144.0/22", "194.28.152.0/21"},
},
{
[]string{"194.28.152.0/23", "194.28.154.0/23", "194.28.156.0/22", "194.28.144.0/22", "194.28.148.0/22"},
[]string{"194.28.144.0/20"},
},
}
for _, test := range tests {
trie := NewTrie()
for _, v := range test.input {
trie.Insert(v)
}
var got []string
for _, v := range trie.Print() {
got = append(got, ConvertToHex(v))
}
if convertToCompare(got) != convertToCompare(test.want) {
t.Errorf("Merge(%q): %q, want %q", test.input, got, test.want)
} else {
t.Logf("Merge(%q): %q", test.input, got)
}
}
}
// 将结果排序并转换为string,便于比较
func convertToCompare(strSlice []string) string {
sort.Strings(strSlice)
return fmt.Sprint(strSlice)
}