lib/protocol: faster Luhn algorithm and better testing (#6475)

The previous implementation was very generic; its tests didn't cover the
actual alphabet for device IDs.

Benchmark results on amd64:

name         old time/op    new time/op     delta
Luhnify-8      1.00µs ± 1%     0.28µs ± 4%   -72.38%  (p=0.000 n=9+10)
Unluhnify-8     992ns ± 2%      274ns ± 1%   -72.39%  (p=0.000 n=10+9)
This commit is contained in:
greatroar 2020-03-29 22:28:04 +02:00 committed by GitHub
parent ea5c9176e1
commit 1e2379df1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 60 deletions

View File

@ -165,7 +165,7 @@ func luhnify(s string) (string, error) {
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
p := s[i*13 : (i+1)*13] p := s[i*13 : (i+1)*13]
copy(res[i*(13+1):], p) copy(res[i*(13+1):], p)
l, err := luhnBase32.generate(p) l, err := luhn32(p)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -183,7 +183,7 @@ func unluhnify(s string) (string, error) {
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
p := s[i*(13+1) : (i+1)*(13+1)-1] p := s[i*(13+1) : (i+1)*(13+1)-1]
copy(res[i*13:], p) copy(res[i*13:], p)
l, err := luhnBase32.generate(p) l, err := luhn32(p)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -2,32 +2,34 @@
package protocol package protocol
import ( import "fmt"
"fmt"
"strings"
)
// An alphabet is a string of N characters, representing the digits of a given var luhnBase32 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
// base N.
type luhnAlphabet string
var ( func codepoint32(b byte) int {
luhnBase32 luhnAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" switch {
) case 'A' <= b && b <= 'Z':
return int(b - 'A')
case '2' <= b && b <= '7':
return int(b + 26 - '2')
default:
return -1
}
}
// generate returns a check digit for the string s, which should be composed // luhn32 returns a check digit for the string s, which should be composed
// of characters from the Alphabet a. // of characters from the alphabet luhnBase32.
// Doesn't follow the actual Luhn algorithm // Doesn't follow the actual Luhn algorithm
// see https://forum.syncthing.net/t/v0-9-0-new-node-id-format/478/6 for more. // see https://forum.syncthing.net/t/v0-9-0-new-node-id-format/478/6 for more.
func (a luhnAlphabet) generate(s string) (rune, error) { func luhn32(s string) (rune, error) {
factor := 1 factor := 1
sum := 0 sum := 0
n := len(a) const n = 32
for i := range s { for i := range s {
codepoint := strings.IndexByte(string(a), s[i]) codepoint := codepoint32(s[i])
if codepoint == -1 { if codepoint == -1 {
return 0, fmt.Errorf("digit %q not valid in alphabet %q", s[i], a) return 0, fmt.Errorf("digit %q not valid in alphabet %q", s[i], luhnBase32)
} }
addend := factor * codepoint addend := factor * codepoint
if factor == 2 { if factor == 2 {
@ -40,16 +42,5 @@ func (a luhnAlphabet) generate(s string) (rune, error) {
} }
remainder := sum % n remainder := sum % n
checkCodepoint := (n - remainder) % n checkCodepoint := (n - remainder) % n
return rune(a[checkCodepoint]), nil return rune(luhnBase32[checkCodepoint]), nil
}
// luhnValidate returns true if the last character of the string s is correct, for
// a string s composed of characters in the alphabet a.
func (a luhnAlphabet) luhnValidate(s string) bool {
t := s[:len(s)-1]
c, err := a.generate(t)
if err != nil {
return false
}
return rune(s[len(s)-1]) == c
} }

View File

@ -3,46 +3,24 @@
package protocol package protocol
import ( import (
"strings"
"testing" "testing"
) )
func TestGenerate(t *testing.T) { func TestLuhn32(t *testing.T) {
// Base 6 Luhn c, err := luhn32("AB725E4GHIQPL3ZFGT")
a := luhnAlphabet("abcdef")
c, err := a.generate("abcdef")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if c != 'e' { if c != 'G' {
t.Errorf("Incorrect check digit %c != e", c) t.Errorf("Incorrect check digit %c != G", c)
} }
// Base 10 Luhn _, err = luhn32("3734EJEKMRHWPZQTWYQ1")
a = luhnAlphabet("0123456789")
c, err = a.generate("7992739871")
if err != nil {
t.Fatal(err)
}
if c != '3' {
t.Errorf("Incorrect check digit %c != 3", c)
}
}
func TestInvalidString(t *testing.T) {
a := luhnAlphabet("ABC")
_, err := a.generate("7992739871")
t.Log(err)
if err == nil { if err == nil {
t.Error("Unexpected nil error") t.Error("Unexpected nil error")
} }
} if !strings.Contains(err.Error(), "'1'") {
t.Errorf("luhn32 should have errored on digit '1', got %v", err)
func TestValidate(t *testing.T) {
a := luhnAlphabet("abcdef")
if !a.luhnValidate("abcdefe") {
t.Errorf("Incorrect validation response for abcdefe")
}
if a.luhnValidate("abcdefd") {
t.Errorf("Incorrect validation response for abcdefd")
} }
} }