2

快速转换string到int

 2 years ago
source link: https://sikasjc.github.io/2021/08/07/string2int/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

原文见此:https://johnnylee-sde.github.io/Fast-numeric-string-to-int/

利用位运算和64位CPU的优势,实现快速的转换string到int,并使用Go来验证。

最简单的版本

  • 常见的字符串转数字的代码如下所示:
// given num[] - ASCII chars containing decimal digits 0-9
int sum = 0;
for (int i = 0; i <= 7; i++)
{
sum = (sum * 10) + (num[i] - '0');
}
  • 一种非常直接优化的方式是将循环展开
int sum;
sum = (num[0] - '0') * 10000000 +
(num[1] - '0') * 1000000 +
(num[2] - '0') * 100000 +
(num[3] - '0') * 10000 +
(num[4] - '0') * 1000 +
(num[5] - '0') * 100 +
(num[6] - '0') * 10 +
(num[7] - '0');
  • 用Golang测试下,循环展开的版本是否有优化
package string2int

import (
"strconv"
"testing"
)

// loop converts the first eight bytes of str to an int, treating each byte as
// an ASCII decimal digit. No validation is performed: bytes outside '0'-'9'
// contribute whatever the byte subtraction yields, bytes past the eighth are
// ignored, and a string shorter than eight bytes causes an index panic.
func loop(str string) int {
	total := 0
	for pos := 0; pos != 8; pos++ {
		digit := int(str[pos] - '0')
		total = 10*total + digit
	}
	return total
}

// loop2 converts every byte of str to an int, treating each byte as an ASCII
// decimal digit. The empty string yields 0. No validation is performed, so
// non-digit bytes contribute whatever the byte subtraction yields.
func loop2(str string) int {
	total := 0
	// Iterate bytes, not runes, to match plain indexing of the string.
	for _, c := range []byte(str) {
		total = total*10 + int(c-'0')
	}
	return total
}

// unrollLoop converts the first eight bytes of str to an int with the loop
// fully unrolled: each digit is multiplied by its decimal weight and the
// products are summed. No validation is performed, and a string shorter than
// eight bytes causes an index panic.
func unrollLoop(str string) int {
	d0 := int(str[0] - '0')
	d1 := int(str[1] - '0')
	d2 := int(str[2] - '0')
	d3 := int(str[3] - '0')
	d4 := int(str[4] - '0')
	d5 := int(str[5] - '0')
	d6 := int(str[6] - '0')
	d7 := int(str[7] - '0')
	return d0*10000000 +
		d1*1000000 +
		d2*100000 +
		d3*10000 +
		d4*1000 +
		d5*100 +
		d6*10 +
		d7
}

// Test_String2Int checks that every hand-written converter agrees with
// strconv.Atoi on an eight-digit input.
func Test_String2Int(t *testing.T) {
	str := "12345678"
	n, err := strconv.Atoi(str)
	if err != nil {
		// Without a valid reference value the comparisons below are meaningless.
		t.Fatalf("strconv.Atoi(%q): %v", str, err)
	}

	// Each converter is evaluated once; the result is reused in the message.
	cases := []struct {
		name string
		got  int
	}{
		{"loop", loop(str)},
		{"loop2", loop2(str)},
		{"unroll loop", unrollLoop(str)},
	}
	for _, c := range cases {
		if c.got != n {
			t.Errorf("%s error, %v != %v", c.name, c.got, n)
		}
	}
}

// BenchSink receives every benchmarked result. Storing into an exported
// package-level variable keeps the compiler from treating the (inlinable,
// side-effect-free) converter calls as dead code and removing them.
var BenchSink int

// Benchmark_String2Int measures strconv.Atoi against the hand-written
// converters on the same eight-digit input, one sub-benchmark per converter.
func Benchmark_String2Int(b *testing.B) {
	str := "12345678"
	b.Run("strconv.Atoi", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			n, err := strconv.Atoi(str)
			if err != nil {
				b.Fatal(err)
			}
			BenchSink = n
		}
	})
	b.Run("loop", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			BenchSink = loop(str)
		}
	})
	b.Run("loop2", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			BenchSink = loop2(str)
		}
	})
	b.Run("unroll loop", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			BenchSink = unrollLoop(str)
		}
	})
}

--------------------------------------------------------------------
goos: darwin
goarch: amd64
pkg: mine/mock/benchmark/string2int
cpu: Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
Benchmark_String2Int
Benchmark_String2Int/strconv.Atoi
Benchmark_String2Int/strconv.Atoi-12 138586400 8.996 ns/op
Benchmark_String2Int/loop
Benchmark_String2Int/loop-12 338002887 3.550 ns/op
Benchmark_String2Int/loop2
Benchmark_String2Int/loop2-12 410646234 3.205 ns/op
Benchmark_String2Int/unroll_loop
Benchmark_String2Int/unroll_loop-12 724936510 1.436 ns/op
PASS

// 循环展开的版本优势是明显的,并且比strconv要快很多,因为strconv要做一些检查。
// 有意思的地方:loop2要比loop快,还不太清楚为什么
  • 两种方法实际上都是O(n)的,n为字符串长度
  • 可以通过统计 loads、adds/subtracts、shifts 和 multiplies 的次数来估计代码的开销:
    • 对于展开循环的版本:

      • 8 loads
      • 8 subtracts, 7 adds, 8 adds to index into num array
      • 7 multiplies
    • 这里假设除了multiplies外每个操作的开销一致,一般multiplies开销会更大一些。

    • 因此是,31次操作 + 7 multiplies

更快的转换

现在 64 位 CPU 和操作系统很常见,我们可以在这个问题上释放 64 位 CPU 寄存器的全部力量。

  • 在 ASCII 字符集中,数字字符('0' ~ '9')的范围是 0x30 ~ 0x39 (48-57)
  • 如果将每个数字 ASCII 字符与 0x0F 按位与,会将数字 ASCII 字符转换为该数字字符对应的十进制数。
'0' ~ '9'             =>   0x30 ~ 0x39
(0x30 ~ 0x39) & 0x0F => 0x0 ~ 0x9
  • 将 8 位数字字符串加载到 64 位 CPU 寄存器中,在 Intel CPU(little-endian)上表现如下
// given the string "12345678", 
// on little-endian Intel CPUs we see the reversed:
sum = 0x3837363534333231
  • 与 0x0F0F0F0F0F0F0F0F 按位与后,每个字节都变成对应的数字
// given the string "12345678",
// bitwise-AND with 0x0F0F0F0F0F0F0F0F
sum = *((long long *)num) & 0x0F0F0F0F0F0F0F0F;
sum == 0x0807060504030201
  • 由于Intel CPU为little-endian,加载进来的数字实际上低位在前,高位在后。
// given the string "12345678", 
// on little-endian Intel CPUs we see the reversed:
sum = 0x3837363534333231
  • 因此我们需要做一些调整,把每两个相邻的数字(十位和个位)合并为一个两位数
    • 用按位与取出每组中的高位数字(十位),并乘以10
    • 右移低位数字到与高位数字相同的位置
    • 取上面两个数的和
// isolate the high digit, multiply by 10,
// shift over the low digit and add in
sum = ((sum & 0x000F000F000F000F) * 10) +
((sum >> 8) & 0x000F000F000F000F);

// sum = 0x0807060504030201
// (sum & 0x000F000F000F000F) * 10 = 0x0007000500030001 * 10
// (sum >> 8) & 0x000F000F000F000F = 0x0008000600040002
// 取两数之和后: [78]代表10进制数字对应的16进制数字
// sum = 0x00[78]00[56]00[34]00[12]
  • 扩展到更大的范围
// numbers are in range 0-99 (0x0-0x63) now
// - isolate high number (use 0x7F which encompasses number range)
// - multiply by 100 to move high number into
// thousands & hundreds position
// - shift low number over to tens and ones position
// - add the two numbers together
sum = ((sum & 0x0000007F0000007F) * 100) + ((sum >> 16) & 0x0000007F0000007F);

// sum = 0x00[78]00[56]00[34]00[12]
// (sum & 0x0000007F0000007F) * 100 = 0x000000[56]000000[12] * 100
// (sum >> 16) & 0x0000007F0000007F = 0x000000[78]000000[34]
// 取两数之和后:
// sum = 0x0000[5678]0000[1234]
// numbers are in range 0-9,999 (0x0-0x270F) now
// isolate high number (use 0x3FFF which covers number range)
// then multiply by 10000 to move high number into position
// shift low number over and isolate
// add the two numbers together
sum = ((sum & 0x3FFF) * 10000) + ((sum >> 32) & 0x3FFF);

// sum = 0x0000[5678]0000[1234]
// (sum & 0x3FFF) * 10000 = 0x000000000000[1234] * 10000
// (sum >> 32) & 0x3FFF = 0x000000000000[5678]
// 取两数之和后:
// sum = 0x00000000[12345678]

最终的算法

// given num[] - ASCII chars containing decimal digits 0-9
long long sum;
sum = *((long long*)num) & 0x0F0F0F0F0F0F0F0F;
sum = ((sum & 0x000F000F000F000F) * 10 ) + ((sum >> 8) & 0x000F000F000F000F);
sum = ((sum & 0x0000007F0000007F) * 100) + ((sum >> 16) & 0x0000007F0000007F);
sum = ((sum & 0x3FFF) * 10000) + ((sum >> 32) & 0x3FFF);
  • 现在的时间复杂度为O(lg n), 会执行
    • 1 load
    • 7 bitwise ANDs
    • 3 right shifts
    • 3 adds
    • 3 multiplies
  • 和循环展开的版本的比较一下
Algorithm       |  Ops   | Multiplies
Unrolled loop | 31 | 7
SIMD | 14 | 3
  • 还可以更进一步(下文称为 SIMD 2):把每一步的"乘法、移位、相加"合并为一次乘法加一次右移
sum = *(long long *)num;
sum = (sum & 0x0F0F0F0F0F0F0F0F) * 2561 >> 8;
sum = (sum & 0x00FF00FF00FF00FF) * 6553601 >> 16;
sum = (sum & 0x0000FFFF0000FFFF) * 42949672960001 >> 32;

这些魔法数字是哪来的?

  • 上面的方法是通过右移不断的累加高位数字
  • 这里其实也是类似的,只是通过左移来处理高位数字的
  • 最后右移,去掉引入的部分
// 2561

sum = (((256 * 10) * sum) + 1 * sum);
// multiply by 256 is the same as left shift by 8
== ((10 * sum) << 8) + (1 * sum);
sum = sum >> 8;

// sum = 0x0807060504030201
// (256 * 10) * sum = (sum << 8) * 10 = 0x0706050403020100 * 10
// (256 * 10) * sum + 1 * sum = 0x0706050403020100 * 10 + 0x0807060504030201
// = 0x[78][67][56][45][34][23][12][01]
// sum >> 8 = 0x00[78][67][56][45][34][23][12]
// 6553601

// number groups are in range 0-99 now
sum = (sum & 0x00FF00FF00FF00FF) * 6553601 >> 16;

// sum = 0x00[78][67][56][45][34][23][12]
// sum & 0x00FF00FF00FF00FF = 0x00[78]00[56]00[34]00[12]

// sum = 0x00[78]00[56]00[34]00[12]
// (65536 * 100) * sum = (sum << 16) * 100 = 0x00[56]00[34]00[12]0000 * 100
// (65536 * 100) * sum + 1 * sum = 0x00[56]00[34]00[12]0000 * 100 + 0x00[78]00[56]00[34]00[12]
// = 0x[5678][3456][1234][0012]
// sum >> 16 = 0x0000[5678][3456][1234]

// 42949672960001

// number groups are in range 0-9,999 now
sum = (sum & 0x0000FFFF0000FFFF) * 42949672960001 >> 32;

// sum = 0x0000[5678][3456][1234]
// sum & 0x0000FFFF0000FFFF = 0x0000[5678]0000[1234]

// sum = 0x0000[5678]0000[1234]
// (4294967296 * 10000) * sum = (sum << 32) * 10000 = 0x0000[1234]00000000 * 10000
// (4294967296 * 10000) * sum + 1 * sum = 0x0000[1234]00000000 * 10000 + 0x0000[5678]0000[1234]
// = 0x[12345678][00001234]
// sum >> 32 = 0x00000000[12345678]
  • 这样进一步减少了操作的步骤
    • 1 load
    • 3 bitwise ANDs
    • 3 right shifts
    • 3 multiplies
  • 和之前的版本比较一下
Algorithm       |  Ops  | Multiplies
Unrolled loop | 31 | 7
SIMD | 14 | 3
SIMD 2 | 7 | 3

用Golang模拟

goos: darwin
goarch: amd64
pkg: mine/mock/benchmark/string2int
cpu: Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
Benchmark_String2Int
Benchmark_String2Int/strconv.Aoti
Benchmark_String2Int/strconv.Aoti-12 127020 8282 ns/op
Benchmark_String2Int/loop
Benchmark_String2Int/loop-12 327012 3719 ns/op
Benchmark_String2Int/loop2
Benchmark_String2Int/loop2-12 406800 2974 ns/op
Benchmark_String2Int/unroll_loop
Benchmark_String2Int/unroll_loop-12 802428 1455 ns/op
Benchmark_String2Int/simd
Benchmark_String2Int/simd-12 4299562 330.2 ns/op
Benchmark_String2Int/simd2
Benchmark_String2Int/simd2-12 3641539 301.2 ns/op
PASS
  • SIMD 和 SIMD2的优势是非常明显的
  • 发现SIMD2 相比于SIMD并没有明显优势(benchmark比较的是1000次的转换开销),还不确定是什么原因
    • 毕竟SIMD 2减少了常规操作,但是引入了相对大的整数乘法
    • 也有可能是常规操作确实开销太小了,不明显
package string2int

import (
"strconv"
"testing"
)

// loop converts the first eight bytes of str to an int, treating each byte as
// an ASCII decimal digit. No validation is performed: bytes outside '0'-'9'
// contribute whatever the byte subtraction yields, bytes past the eighth are
// ignored, and a string shorter than eight bytes causes an index panic.
func loop(str string) int {
	total := 0
	for pos := 0; pos != 8; pos++ {
		digit := int(str[pos] - '0')
		total = 10*total + digit
	}
	return total
}

// loop2 converts every byte of str to an int, treating each byte as an ASCII
// decimal digit. The empty string yields 0. No validation is performed, so
// non-digit bytes contribute whatever the byte subtraction yields.
func loop2(str string) int {
	total := 0
	// Iterate bytes, not runes, to match plain indexing of the string.
	for _, c := range []byte(str) {
		total = total*10 + int(c-'0')
	}
	return total
}

// unrollLoop converts the first eight bytes of str to an int with the loop
// fully unrolled: each digit is multiplied by its decimal weight and the
// products are summed. No validation is performed, and a string shorter than
// eight bytes causes an index panic.
func unrollLoop(str string) int {
	d0 := int(str[0] - '0')
	d1 := int(str[1] - '0')
	d2 := int(str[2] - '0')
	d3 := int(str[3] - '0')
	d4 := int(str[4] - '0')
	d5 := int(str[5] - '0')
	d6 := int(str[6] - '0')
	d7 := int(str[7] - '0')
	return d0*10000000 +
		d1*1000000 +
		d2*100000 +
		d3*10000 +
		d4*1000 +
		d5*100 +
		d6*10 +
		d7
}

// simd simulates the register-wide conversion on a fixed input: the eight
// ASCII bytes of "12345678" as they appear after a little-endian 64-bit load
// (first character in the lowest byte). It returns 12345678.
//
// Adjacent groups are merged in three rounds (digits -> 2-digit numbers ->
// 4-digit numbers -> the 8-digit result); each round multiplies the
// numerically higher group by its weight and adds the lower group shifted
// down into the same lane.
func simd() int {
	// uint64 rather than int: the constant needs all 64 bits, so an int-typed
	// value would not compile where int is 32 bits wide.
	var word uint64 = 0x3837363534333231

	// Strip the ASCII high nibble: every byte becomes its digit value 0-9.
	num := word & 0x0f0f0f0f0f0f0f0f
	// Go gives & the same precedence as *, so the parentheses around the
	// second operand are for readers only (in C they would be required).
	num = (num&0x000f000f000f000f)*10 + ((num >> 8) & 0x000f000f000f000f)
	num = (num&0x0000007f0000007f)*100 + ((num >> 16) & 0x0000007f0000007f)
	num = (num&0x3fff)*10000 + ((num >> 32) & 0x3fff)
	return int(num)
}

// simd2 simulates the multiply-and-shift variant of the register-wide
// conversion on a fixed input: the eight ASCII bytes of "12345678" as they
// appear after a little-endian 64-bit load. It returns 12345678.
//
// Each round uses one multiplication by (weight<<k + 1), which adds the
// weighted neighbouring lane into each lane, followed by a right shift by k
// that drops the lowest lane.
func simd2() int {
	// uint64 rather than int: the constant needs all 64 bits (an int-typed
	// value would not compile where int is 32 bits wide), and the
	// multiplications deliberately let the top bits wrap off, which is
	// clearest as unsigned arithmetic.
	var word uint64 = 0x3837363534333231

	num := (word & 0x0f0f0f0f0f0f0f0f) * 2561 >> 8            // 2561 = 10<<8 + 1
	num = (num & 0x00ff00ff00ff00ff) * 6553601 >> 16          // 6553601 = 100<<16 + 1
	num = (num & 0x0000ffff0000ffff) * 42949672960001 >> 32   // 42949672960001 = 10000<<32 + 1
	return int(num)
}

// Test_String2Int checks that every hand-written converter agrees with
// strconv.Atoi on the eight-digit input "12345678". simd and simd2 take no
// argument; they operate on a built-in constant encoding the same digits.
func Test_String2Int(t *testing.T) {
	str := "12345678"
	n, err := strconv.Atoi(str)
	if err != nil {
		// Without a valid reference value the comparisons below are meaningless.
		t.Fatalf("strconv.Atoi(%q): %v", str, err)
	}

	// Each converter is evaluated once; the result is reused in the message.
	cases := []struct {
		name string
		got  int
	}{
		{"loop", loop(str)},
		{"loop2", loop2(str)},
		{"unroll loop", unrollLoop(str)},
		{"simd", simd()},
		{"simd2", simd2()},
	}
	for _, c := range cases {
		if c.got != n {
			t.Errorf("%s error, %v != %v", c.name, c.got, n)
		}
	}
}

// BenchSink receives every benchmarked result. Storing into an exported
// package-level variable keeps the compiler from treating the (inlinable,
// side-effect-free) converter calls as dead code and removing them; this
// matters most for simd and simd2, which take no input at all.
var BenchSink int

// Benchmark_String2Int measures strconv.Atoi against the hand-written
// converters. Every benchmark iteration performs 1000 conversions, so the
// reported ns/op is the cost of 1000 calls.
func Benchmark_String2Int(b *testing.B) {
	str := "12345678"
	// Sub-benchmark name kept as originally written (sic) so results stay
	// comparable with previously recorded output.
	b.Run("strconv.Aoti", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				n, err := strconv.Atoi(str)
				if err != nil {
					b.Fatal(err)
				}
				BenchSink = n
			}
		}
	})
	b.Run("loop", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				BenchSink = loop(str)
			}
		}
	})
	b.Run("loop2", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				BenchSink = loop2(str)
			}
		}
	})
	b.Run("unroll loop", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				BenchSink = unrollLoop(str)
			}
		}
	})
	b.Run("simd", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				BenchSink = simd()
			}
		}
	})
	b.Run("simd2", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			for j := 0; j < 1000; j++ {
				BenchSink = simd2()
			}
		}
	})
}

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK