Add some tests (#553)

This commit is contained in:
Thomas Miceli
2025-10-31 15:37:45 +07:00
committed by GitHub
parent 8129906b02
commit 3957dfb3ea
4 changed files with 1481 additions and 0 deletions

View File

@@ -0,0 +1,430 @@
package totp
import (
"bytes"
"crypto/aes"
"testing"
)
func TestAESEncrypt(t *testing.T) {
tests := []struct {
name string
key []byte
text []byte
wantErr bool
}{
{
name: "basic encryption with 16-byte key",
key: []byte("1234567890123456"), // 16 bytes (AES-128)
text: []byte("hello world"),
wantErr: false,
},
{
name: "basic encryption with 24-byte key",
key: []byte("123456789012345678901234"), // 24 bytes (AES-192)
text: []byte("hello world"),
wantErr: false,
},
{
name: "basic encryption with 32-byte key",
key: []byte("12345678901234567890123456789012"), // 32 bytes (AES-256)
text: []byte("hello world"),
wantErr: false,
},
{
name: "empty text",
key: []byte("1234567890123456"),
text: []byte(""),
wantErr: false,
},
{
name: "long text",
key: []byte("1234567890123456"),
text: []byte("This is a much longer text that spans multiple blocks and should be encrypted properly without any issues"),
wantErr: false,
},
{
name: "binary data",
key: []byte("1234567890123456"),
text: []byte{0x00, 0x01, 0x02, 0x03, 0xFF, 0xFE, 0xFD},
wantErr: false,
},
{
name: "invalid key length - too short",
key: []byte("short"),
text: []byte("hello world"),
wantErr: true,
},
{
name: "invalid key length - 17 bytes",
key: []byte("12345678901234567"), // 17 bytes (invalid)
text: []byte("hello world"),
wantErr: true,
},
{
name: "nil key",
key: nil,
text: []byte("hello world"),
wantErr: true,
},
{
name: "empty key",
key: []byte(""),
text: []byte("hello world"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := AESEncrypt(tt.key, tt.text)
if (err != nil) != tt.wantErr {
t.Errorf("AESEncrypt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
// Verify ciphertext is not empty
if len(ciphertext) == 0 {
t.Error("AESEncrypt() returned empty ciphertext")
}
// Verify ciphertext length is correct (IV + encrypted text)
expectedLen := aes.BlockSize + len(tt.text)
if len(ciphertext) != expectedLen {
t.Errorf("AESEncrypt() ciphertext length = %d, want %d", len(ciphertext), expectedLen)
}
// Verify ciphertext is different from plaintext (unless text is empty)
if len(tt.text) > 0 && bytes.Equal(ciphertext[aes.BlockSize:], tt.text) {
t.Error("AESEncrypt() ciphertext matches plaintext")
}
// Verify IV is present and non-zero
iv := ciphertext[:aes.BlockSize]
allZeros := true
for _, b := range iv {
if b != 0 {
allZeros = false
break
}
}
if allZeros {
t.Error("AESEncrypt() IV is all zeros")
}
}
})
}
}
func TestAESDecrypt(t *testing.T) {
validKey := []byte("1234567890123456")
validText := []byte("hello world")
// Encrypt some data to use for valid test cases
validCiphertext, err := AESEncrypt(validKey, validText)
if err != nil {
t.Fatalf("Failed to create valid ciphertext: %v", err)
}
tests := []struct {
name string
key []byte
ciphertext []byte
wantErr bool
}{
{
name: "valid decryption",
key: validKey,
ciphertext: validCiphertext,
wantErr: false,
},
{
name: "ciphertext too short - empty",
key: validKey,
ciphertext: []byte(""),
wantErr: true,
},
{
name: "ciphertext too short - less than block size",
key: validKey,
ciphertext: []byte("short"),
wantErr: true,
},
{
name: "ciphertext exactly block size (IV only, no data)",
key: validKey,
ciphertext: make([]byte, aes.BlockSize),
wantErr: false,
},
{
name: "invalid key length",
key: []byte("short"),
ciphertext: validCiphertext,
wantErr: true,
},
{
name: "wrong key",
key: []byte("6543210987654321"),
ciphertext: validCiphertext,
wantErr: false, // Decryption succeeds but produces garbage
},
{
name: "nil key",
key: nil,
ciphertext: validCiphertext,
wantErr: true,
},
{
name: "nil ciphertext",
key: validKey,
ciphertext: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
plaintext, err := AESDecrypt(tt.key, tt.ciphertext)
if (err != nil) != tt.wantErr {
t.Errorf("AESDecrypt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
// For valid decryption with correct key, verify we get original text
if tt.name == "valid decryption" && !bytes.Equal(plaintext, validText) {
t.Errorf("AESDecrypt() plaintext = %v, want %v", plaintext, validText)
}
// For ciphertext with only IV, plaintext should be empty
if tt.name == "ciphertext exactly block size (IV only, no data)" && len(plaintext) != 0 {
t.Errorf("AESDecrypt() plaintext length = %d, want 0", len(plaintext))
}
}
})
}
}
func TestAESEncryptDecrypt_RoundTrip(t *testing.T) {
tests := []struct {
name string
key []byte
text []byte
}{
{
name: "basic round trip",
key: []byte("1234567890123456"),
text: []byte("hello world"),
},
{
name: "empty text round trip",
key: []byte("1234567890123456"),
text: []byte(""),
},
{
name: "long text round trip",
key: []byte("1234567890123456"),
text: []byte("This is a very long text that contains multiple blocks of data and should be encrypted and decrypted correctly without any data loss or corruption"),
},
{
name: "binary data round trip",
key: []byte("1234567890123456"),
text: []byte{0x00, 0x01, 0x02, 0x03, 0xFF, 0xFE, 0xFD, 0xFC},
},
{
name: "unicode text round trip",
key: []byte("1234567890123456"),
text: []byte("Hello 世界! 🔐 Encryption"),
},
{
name: "AES-192 round trip",
key: []byte("123456789012345678901234"),
text: []byte("testing AES-192"),
},
{
name: "AES-256 round trip",
key: []byte("12345678901234567890123456789012"),
text: []byte("testing AES-256"),
},
{
name: "special characters",
key: []byte("1234567890123456"),
text: []byte("!@#$%^&*()_+-=[]{}|;':\",./<>?"),
},
{
name: "newlines and tabs",
key: []byte("1234567890123456"),
text: []byte("line1\nline2\tline3\r\nline4"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Encrypt
ciphertext, err := AESEncrypt(tt.key, tt.text)
if err != nil {
t.Fatalf("AESEncrypt() failed: %v", err)
}
// Decrypt
plaintext, err := AESDecrypt(tt.key, ciphertext)
if err != nil {
t.Fatalf("AESDecrypt() failed: %v", err)
}
// Verify plaintext matches original
if !bytes.Equal(plaintext, tt.text) {
t.Errorf("Round trip failed: got %v, want %v", plaintext, tt.text)
}
})
}
}
func TestAESEncrypt_Uniqueness(t *testing.T) {
key := []byte("1234567890123456")
text := []byte("hello world")
iterations := 10
ciphertexts := make(map[string]bool)
for i := 0; i < iterations; i++ {
ciphertext, err := AESEncrypt(key, text)
if err != nil {
t.Fatalf("Iteration %d failed: %v", i, err)
}
// Each encryption should produce different ciphertext (due to random IV)
ciphertextStr := string(ciphertext)
if ciphertexts[ciphertextStr] {
t.Errorf("Duplicate ciphertext generated at iteration %d", i)
}
ciphertexts[ciphertextStr] = true
// But all should decrypt to the same plaintext
plaintext, err := AESDecrypt(key, ciphertext)
if err != nil {
t.Fatalf("Iteration %d decryption failed: %v", i, err)
}
if !bytes.Equal(plaintext, text) {
t.Errorf("Iteration %d: decrypted text doesn't match original", i)
}
}
}
func TestAESEncrypt_IVUniqueness(t *testing.T) {
key := []byte("1234567890123456")
text := []byte("test data")
iterations := 20
ivs := make(map[string]bool)
for i := 0; i < iterations; i++ {
ciphertext, err := AESEncrypt(key, text)
if err != nil {
t.Fatalf("Iteration %d failed: %v", i, err)
}
// Extract IV (first block)
iv := ciphertext[:aes.BlockSize]
ivStr := string(iv)
// Each IV should be unique
if ivs[ivStr] {
t.Errorf("Duplicate IV generated at iteration %d", i)
}
ivs[ivStr] = true
}
}
func TestAESDecrypt_WrongKey(t *testing.T) {
originalKey := []byte("1234567890123456")
wrongKey := []byte("6543210987654321")
text := []byte("secret message")
// Encrypt with original key
ciphertext, err := AESEncrypt(originalKey, text)
if err != nil {
t.Fatalf("AESEncrypt() failed: %v", err)
}
// Decrypt with wrong key - should not error but produce wrong plaintext
plaintext, err := AESDecrypt(wrongKey, ciphertext)
if err != nil {
t.Fatalf("AESDecrypt() with wrong key failed: %v", err)
}
// Plaintext should be different from original
if bytes.Equal(plaintext, text) {
t.Error("AESDecrypt() with wrong key produced correct plaintext")
}
}
func TestAESDecrypt_CorruptedCiphertext(t *testing.T) {
key := []byte("1234567890123456")
text := []byte("hello world")
// Encrypt
ciphertext, err := AESEncrypt(key, text)
if err != nil {
t.Fatalf("AESEncrypt() failed: %v", err)
}
// Corrupt the ciphertext (flip a bit in the encrypted data, not the IV)
if len(ciphertext) > aes.BlockSize {
corruptedCiphertext := make([]byte, len(ciphertext))
copy(corruptedCiphertext, ciphertext)
corruptedCiphertext[aes.BlockSize] ^= 0xFF
// Decrypt corrupted ciphertext - should not error but produce wrong plaintext
plaintext, err := AESDecrypt(key, corruptedCiphertext)
if err != nil {
t.Fatalf("AESDecrypt() with corrupted ciphertext failed: %v", err)
}
// Plaintext should be different from original
if bytes.Equal(plaintext, text) {
t.Error("AESDecrypt() with corrupted ciphertext produced correct plaintext")
}
}
}
func TestAESEncryptDecrypt_DifferentKeySizes(t *testing.T) {
tests := []struct {
name string
keySize int
}{
{"AES-128", 16},
{"AES-192", 24},
{"AES-256", 32},
}
text := []byte("test message for different key sizes")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate key of specified size
key := make([]byte, tt.keySize)
for i := range key {
key[i] = byte(i)
}
// Encrypt
ciphertext, err := AESEncrypt(key, text)
if err != nil {
t.Fatalf("AESEncrypt() failed: %v", err)
}
// Decrypt
plaintext, err := AESDecrypt(key, ciphertext)
if err != nil {
t.Fatalf("AESDecrypt() failed: %v", err)
}
// Verify
if !bytes.Equal(plaintext, text) {
t.Errorf("Round trip failed for %s", tt.name)
}
})
}
}

View File

@@ -0,0 +1,431 @@
package totp
import (
"encoding/base64"
"strings"
"sync"
"testing"
"time"
"github.com/pquerna/otp/totp"
)
func TestGenerateQRCode(t *testing.T) {
tests := []struct {
name string
username string
siteUrl string
secret []byte
wantErr bool
}{
{
name: "basic generation with nil secret",
username: "testuser",
siteUrl: "opengist.io",
secret: nil,
wantErr: false,
},
{
name: "basic generation with provided secret",
username: "testuser",
siteUrl: "opengist.io",
secret: []byte("1234567890123456"),
wantErr: false,
},
{
name: "username with special characters",
username: "test.user",
siteUrl: "opengist.io",
secret: nil,
wantErr: false,
},
{
name: "site URL with protocol and port",
username: "testuser",
siteUrl: "https://opengist.io:6157",
secret: nil,
wantErr: false,
},
{
name: "empty username",
username: "",
siteUrl: "opengist.io",
secret: nil,
wantErr: true,
},
{
name: "empty site URL",
username: "testuser",
siteUrl: "",
secret: nil,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
secretStr, qrcode, secretBytes, err := GenerateQRCode(tt.username, tt.siteUrl, tt.secret)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateQRCode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
// Verify secret string is not empty
if secretStr == "" {
t.Error("GenerateQRCode() returned empty secret string")
}
// Verify QR code image is generated
if qrcode == "" {
t.Error("GenerateQRCode() returned empty QR code")
}
// Verify QR code has correct data URI prefix
if !strings.HasPrefix(string(qrcode), "data:image/png;base64,") {
t.Errorf("QR code does not have correct data URI prefix: %s", qrcode[:50])
}
// Verify QR code is valid base64 after prefix
base64Data := strings.TrimPrefix(string(qrcode), "data:image/png;base64,")
_, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
t.Errorf("QR code base64 data is invalid: %v", err)
}
// Verify secret bytes are returned
if secretBytes == nil {
t.Error("GenerateQRCode() returned nil secret bytes")
}
// Verify secret bytes have correct length
if len(secretBytes) != secretSize {
t.Errorf("Secret bytes length = %d, want %d", len(secretBytes), secretSize)
}
// If a secret was provided, verify it matches what was returned
if tt.secret != nil && string(secretBytes) != string(tt.secret) {
t.Error("Returned secret bytes do not match provided secret")
}
}
})
}
}
func TestGenerateQRCode_SecretUniqueness(t *testing.T) {
username := "testuser"
siteUrl := "opengist.io"
iterations := 10
secrets := make(map[string]bool)
secretBytes := make(map[string]bool)
for i := 0; i < iterations; i++ {
secretStr, _, secret, err := GenerateQRCode(username, siteUrl, nil)
if err != nil {
t.Fatalf("Iteration %d failed: %v", i, err)
}
// Check secret string uniqueness
if secrets[secretStr] {
t.Errorf("Duplicate secret string generated at iteration %d", i)
}
secrets[secretStr] = true
// Check secret bytes uniqueness
secretKey := string(secret)
if secretBytes[secretKey] {
t.Errorf("Duplicate secret bytes generated at iteration %d", i)
}
secretBytes[secretKey] = true
}
}
func TestGenerateQRCode_WithProvidedSecret(t *testing.T) {
username := "testuser"
siteUrl := "opengist.io"
providedSecret := []byte("mysecret12345678")
// Generate QR code multiple times with the same secret
secretStr1, _, secret1, err := GenerateQRCode(username, siteUrl, providedSecret)
if err != nil {
t.Fatalf("First generation failed: %v", err)
}
secretStr2, _, secret2, err := GenerateQRCode(username, siteUrl, providedSecret)
if err != nil {
t.Fatalf("Second generation failed: %v", err)
}
// Secret strings should be the same when using the same input secret
if secretStr1 != secretStr2 {
t.Error("Secret strings differ when using the same provided secret")
}
// Secret bytes should match the provided secret
if string(secret1) != string(providedSecret) {
t.Error("Returned secret bytes do not match provided secret (first call)")
}
if string(secret2) != string(providedSecret) {
t.Error("Returned secret bytes do not match provided secret (second call)")
}
}
func TestGenerateQRCode_ConcurrentGeneration(t *testing.T) {
username := "testuser"
siteUrl := "opengist.io"
concurrency := 10
type result struct {
secretStr string
secretBytes []byte
err error
}
results := make(chan result, concurrency)
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
secretStr, _, secretBytes, err := GenerateQRCode(username, siteUrl, nil)
results <- result{secretStr: secretStr, secretBytes: secretBytes, err: err}
}()
}
wg.Wait()
close(results)
secrets := make(map[string]bool)
for res := range results {
if res.err != nil {
t.Errorf("Concurrent generation failed: %v", res.err)
continue
}
// Check for duplicates
if secrets[res.secretStr] {
t.Error("Duplicate secret generated in concurrent test")
}
secrets[res.secretStr] = true
}
}
func TestValidate(t *testing.T) {
// Generate a valid secret for testing
_, _, secret, err := GenerateQRCode("testuser", "opengist.io", nil)
if err != nil {
t.Fatalf("Failed to generate secret: %v", err)
}
// Convert secret bytes to base32 string for TOTP
secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", secret)
if err != nil {
t.Fatalf("Failed to generate secret string: %v", err)
}
// Generate a valid passcode for the current time
validPasscode, err := totp.GenerateCode(secretStr, time.Now())
if err != nil {
t.Fatalf("Failed to generate valid passcode: %v", err)
}
tests := []struct {
name string
passcode string
secret string
wantValid bool
}{
{
name: "valid passcode",
passcode: validPasscode,
secret: secretStr,
wantValid: true,
},
{
name: "invalid passcode - wrong digits",
passcode: "000000",
secret: secretStr,
wantValid: false,
},
{
name: "invalid passcode - wrong length",
passcode: "123",
secret: secretStr,
wantValid: false,
},
{
name: "empty passcode",
passcode: "",
secret: secretStr,
wantValid: false,
},
{
name: "empty secret",
passcode: validPasscode,
secret: "",
wantValid: false,
},
{
name: "invalid secret format",
passcode: validPasscode,
secret: "not-a-valid-base32-secret!@#",
wantValid: false,
},
{
name: "passcode with letters",
passcode: "12345A",
secret: secretStr,
wantValid: false,
},
{
name: "passcode with spaces",
passcode: "123 456",
secret: secretStr,
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
valid := Validate(tt.passcode, tt.secret)
if valid != tt.wantValid {
t.Errorf("Validate() = %v, want %v", valid, tt.wantValid)
}
})
}
}
func TestValidate_TimeDrift(t *testing.T) {
// Generate a valid secret
secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", nil)
if err != nil {
t.Fatalf("Failed to generate secret: %v", err)
}
// Test that passcodes from previous and next time windows are accepted
// (TOTP typically accepts codes from ±1 time window for clock drift)
pastTime := time.Now().Add(-30 * time.Second)
futureTime := time.Now().Add(30 * time.Second)
pastPasscode, err := totp.GenerateCode(secretStr, pastTime)
if err != nil {
t.Fatalf("Failed to generate past passcode: %v", err)
}
futurePasscode, err := totp.GenerateCode(secretStr, futureTime)
if err != nil {
t.Fatalf("Failed to generate future passcode: %v", err)
}
// These should be valid due to time drift tolerance
if !Validate(pastPasscode, secretStr) {
t.Error("Validate() rejected passcode from previous time window")
}
if !Validate(futurePasscode, secretStr) {
t.Error("Validate() rejected passcode from next time window")
}
}
func TestValidate_ExpiredPasscode(t *testing.T) {
// Generate a valid secret
secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", nil)
if err != nil {
t.Fatalf("Failed to generate secret: %v", err)
}
// Generate a passcode from 2 minutes ago (should be expired)
oldTime := time.Now().Add(-2 * time.Minute)
oldPasscode, err := totp.GenerateCode(secretStr, oldTime)
if err != nil {
t.Fatalf("Failed to generate old passcode: %v", err)
}
// This should be invalid
if Validate(oldPasscode, secretStr) {
t.Error("Validate() accepted expired passcode from 2 minutes ago")
}
}
func TestValidate_RoundTrip(t *testing.T) {
// Test full round trip: generate secret, generate code, validate code
tests := []struct {
name string
username string
siteUrl string
}{
{
name: "basic round trip",
username: "testuser",
siteUrl: "opengist.io",
},
{
name: "round trip with dot in username",
username: "test.user",
siteUrl: "opengist.io",
},
{
name: "round trip with hyphen in username",
username: "test-user",
siteUrl: "opengist.io",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate QR code and secret
secretStr, _, _, err := GenerateQRCode(tt.username, tt.siteUrl, nil)
if err != nil {
t.Fatalf("GenerateQRCode() failed: %v", err)
}
// Generate a valid passcode
passcode, err := totp.GenerateCode(secretStr, time.Now())
if err != nil {
t.Fatalf("GenerateCode() failed: %v", err)
}
// Validate the passcode
if !Validate(passcode, secretStr) {
t.Error("Validate() rejected valid passcode")
}
// Validate wrong passcode fails
wrongPasscode := "000000"
if passcode == wrongPasscode {
wrongPasscode = "111111"
}
if Validate(wrongPasscode, secretStr) {
t.Error("Validate() accepted invalid passcode")
}
})
}
}
func TestGenerateSecret(t *testing.T) {
// Test the internal generateSecret function behavior through GenerateQRCode
for i := 0; i < 10; i++ {
_, _, secret, err := GenerateQRCode("testuser", "opengist.io", nil)
if err != nil {
t.Fatalf("Iteration %d: generateSecret() failed: %v", i, err)
}
if len(secret) != secretSize {
t.Errorf("Iteration %d: secret length = %d, want %d", i, len(secret), secretSize)
}
// Verify secret is not all zeros (extremely unlikely with crypto/rand)
allZeros := true
for _, b := range secret {
if b != 0 {
allZeros = false
break
}
}
if allZeros {
t.Errorf("Iteration %d: secret is all zeros", i)
}
}
}