From 1e2379df1bb199b80c8fd2c7f472386d2cf8cfeb Mon Sep 17 00:00:00 2001 From: greatroar <61184462+greatroar@users.noreply.github.com> Date: Sun, 29 Mar 2020 22:28:04 +0200 Subject: [PATCH] lib/protocol: faster Luhn algorithm and better testing (#6475) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous implementation was very generic; its tests didn't cover the actual alphabet for device IDs. Benchmark results on amd64: name old time/op new time/op delta Luhnify-8 1.00µs ± 1% 0.28µs ± 4% -72.38% (p=0.000 n=9+10) Unluhnify-8 992ns ± 2% 274ns ± 1% -72.39% (p=0.000 n=10+9) --- lib/protocol/deviceid.go | 4 ++-- lib/protocol/luhn.go | 47 ++++++++++++++++----------------------- lib/protocol/luhn_test.go | 38 +++++++------------------------ 3 files changed, 29 insertions(+), 60 deletions(-) diff --git a/lib/protocol/deviceid.go b/lib/protocol/deviceid.go index 015080387..8cac191ca 100644 --- a/lib/protocol/deviceid.go +++ b/lib/protocol/deviceid.go @@ -165,7 +165,7 @@ func luhnify(s string) (string, error) { for i := 0; i < 4; i++ { p := s[i*13 : (i+1)*13] copy(res[i*(13+1):], p) - l, err := luhnBase32.generate(p) + l, err := luhn32(p) if err != nil { return "", err } @@ -183,7 +183,7 @@ func unluhnify(s string) (string, error) { for i := 0; i < 4; i++ { p := s[i*(13+1) : (i+1)*(13+1)-1] copy(res[i*13:], p) - l, err := luhnBase32.generate(p) + l, err := luhn32(p) if err != nil { return "", err } diff --git a/lib/protocol/luhn.go b/lib/protocol/luhn.go index 49e318064..340276541 100644 --- a/lib/protocol/luhn.go +++ b/lib/protocol/luhn.go @@ -2,32 +2,34 @@ package protocol -import ( - "fmt" - "strings" -) +import "fmt" -// An alphabet is a string of N characters, representing the digits of a given -// base N. -type luhnAlphabet string +var luhnBase32 = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" -var ( - luhnBase32 luhnAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" -) +func codepoint32(b byte) int { + switch { + case 'A' <= b && b <= 'Z': + return int(b - 'A') + case '2' <= b && b <= '7': + return int(b + 26 - '2') + default: + return -1 + } +} -// generate returns a check digit for the string s, which should be composed -// of characters from the Alphabet a. +// luhn32 returns a check digit for the string s, which should be composed +// of characters from the alphabet luhnBase32. // Doesn't follow the actual Luhn algorithm // see https://forum.syncthing.net/t/v0-9-0-new-node-id-format/478/6 for more. -func (a luhnAlphabet) generate(s string) (rune, error) { +func luhn32(s string) (rune, error) { factor := 1 sum := 0 - n := len(a) + const n = 32 for i := range s { - codepoint := strings.IndexByte(string(a), s[i]) + codepoint := codepoint32(s[i]) if codepoint == -1 { - return 0, fmt.Errorf("digit %q not valid in alphabet %q", s[i], a) + return 0, fmt.Errorf("digit %q not valid in alphabet %q", s[i], luhnBase32) } addend := factor * codepoint if factor == 2 { @@ -40,16 +42,5 @@ func (a luhnAlphabet) generate(s string) (rune, error) { } remainder := sum % n checkCodepoint := (n - remainder) % n - return rune(a[checkCodepoint]), nil -} - -// luhnValidate returns true if the last character of the string s is correct, for -// a string s composed of characters in the alphabet a. -func (a luhnAlphabet) luhnValidate(s string) bool { - t := s[:len(s)-1] - c, err := a.generate(t) - if err != nil { - return false - } - return rune(s[len(s)-1]) == c + return rune(luhnBase32[checkCodepoint]), nil } diff --git a/lib/protocol/luhn_test.go b/lib/protocol/luhn_test.go index fe7f80c11..c53c6adac 100644 --- a/lib/protocol/luhn_test.go +++ b/lib/protocol/luhn_test.go @@ -3,46 +3,24 @@ package protocol import ( + "strings" "testing" ) -func TestGenerate(t *testing.T) { - // Base 6 Luhn - a := luhnAlphabet("abcdef") - c, err := a.generate("abcdef") +func TestLuhn32(t *testing.T) { + c, err := luhn32("AB725E4GHIQPL3ZFGT") if err != nil { t.Fatal(err) } - if c != 'e' { - t.Errorf("Incorrect check digit %c != e", c) + if c != 'G' { + t.Errorf("Incorrect check digit %c != G", c) } - // Base 10 Luhn - a = luhnAlphabet("0123456789") - c, err = a.generate("7992739871") - if err != nil { - t.Fatal(err) - } - if c != '3' { - t.Errorf("Incorrect check digit %c != 3", c) - } -} - -func TestInvalidString(t *testing.T) { - a := luhnAlphabet("ABC") - _, err := a.generate("7992739871") - t.Log(err) + _, err = luhn32("3734EJEKMRHWPZQTWYQ1") if err == nil { t.Error("Unexpected nil error") } -} - -func TestValidate(t *testing.T) { - a := luhnAlphabet("abcdef") - if !a.luhnValidate("abcdefe") { - t.Errorf("Incorrect validation response for abcdefe") - } - if a.luhnValidate("abcdefd") { - t.Errorf("Incorrect validation response for abcdefd") + if !strings.Contains(err.Error(), "'1'") { + t.Errorf("luhn32 should have errored on digit '1', got %v", err) } }