From 3957dfb3ea27d3ce84b9295f28b46c55b18db7cd Mon Sep 17 00:00:00 2001 From: Thomas Miceli <27960254+thomiceli@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:37:45 +0700 Subject: [PATCH] Add some tests (#553) --- internal/auth/password/argon2id_test.go | 427 +++++++++++++++++++++++ internal/auth/password/password_test.go | 193 +++++++++++ internal/auth/totp/aes_test.go | 430 +++++++++++++++++++++++ internal/auth/totp/totp_test.go | 431 ++++++++++++++++++++++++ 4 files changed, 1481 insertions(+) create mode 100644 internal/auth/password/argon2id_test.go create mode 100644 internal/auth/password/password_test.go create mode 100644 internal/auth/totp/aes_test.go create mode 100644 internal/auth/totp/totp_test.go diff --git a/internal/auth/password/argon2id_test.go b/internal/auth/password/argon2id_test.go new file mode 100644 index 0000000..803cbd2 --- /dev/null +++ b/internal/auth/password/argon2id_test.go @@ -0,0 +1,427 @@ +package password + +import ( + "encoding/base64" + "strings" + "testing" +) + +func TestArgon2ID_Hash(t *testing.T) { + tests := []struct { + name string + plain string + wantErr bool + }{ + { + name: "basic password", + plain: "password123", + wantErr: false, + }, + { + name: "empty string", + plain: "", + wantErr: false, + }, + { + name: "long password", + plain: strings.Repeat("a", 10000), + wantErr: false, + }, + { + name: "unicode password", + plain: "パスワード🔒", + wantErr: false, + }, + { + name: "special characters", + plain: "!@#$%^&*()_+-=[]{}|;:',.<>?/`~", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := Argon2id.Hash(tt.plain) + if (err != nil) != tt.wantErr { + t.Errorf("Argon2id.Hash() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // Verify the hash format + if !strings.HasPrefix(hash, "$argon2id$") { + t.Errorf("Hash does not start with $argon2id$: %v", hash) + } + + // Verify all parts are present + parts := strings.Split(hash, "$") + if len(parts) != 6 { + t.Errorf("Hash has %d parts, expected 6: %v", len(parts), hash) + } + + // Verify salt is properly encoded + if len(parts) >= 5 { + _, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + t.Errorf("Salt is not properly base64 encoded: %v", err) + } + } + + // Verify hash is properly encoded + if len(parts) >= 6 { + _, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + t.Errorf("Hash is not properly base64 encoded: %v", err) + } + } + } + }) + } +} + +func TestArgon2ID_Verify(t *testing.T) { + // Generate a valid hash for testing + testPassword := "correctpassword" + validHash, err := Argon2id.Hash(testPassword) + if err != nil { + t.Fatalf("Failed to generate test hash: %v", err) + } + + tests := []struct { + name string + plain string + hash string + wantMatch bool + wantErr bool + }{ + { + name: "correct password", + plain: testPassword, + hash: validHash, + wantMatch: true, + wantErr: false, + }, + { + name: "incorrect password", + plain: "wrongpassword", + hash: validHash, + wantMatch: false, + wantErr: false, + }, + { + name: "empty password", + plain: "", + hash: validHash, + wantMatch: false, + wantErr: false, + }, + { + name: "empty hash", + plain: testPassword, + hash: "", + wantMatch: false, + wantErr: false, + }, + { + name: "invalid hash - too few parts", + plain: testPassword, + hash: "$argon2id$v=19$m=65536", + wantMatch: false, + wantErr: true, + }, + { + name: "invalid hash - too many parts", + plain: testPassword, + hash: "$argon2id$v=19$m=65536,t=1,p=4$salt$hash$extra", + wantMatch: false, + wantErr: true, + }, + { + name: "invalid hash - malformed parameters", + plain: testPassword, + hash: "$argon2id$v=19$invalid$salt$hash", + wantMatch: false, + wantErr: true, + }, + { + name: "invalid hash - bad base64 salt", + plain: testPassword, + hash: "$argon2id$v=19$m=65536,t=1,p=4$not-valid-base64!@#$hash", + wantMatch: false, + wantErr: true, + }, + { + name: "invalid hash - bad base64 hash", + plain: testPassword, + hash: "$argon2id$v=19$m=65536,t=1,p=4$dGVzdA$not-valid-base64!@#", + wantMatch: false, + wantErr: true, + }, + { + name: "wrong algorithm prefix", + plain: testPassword, + hash: "$bcrypt$rounds=10$saltsaltsaltsaltsalt", + wantMatch: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + match, err := Argon2id.Verify(tt.plain, tt.hash) + if (err != nil) != tt.wantErr { + t.Errorf("Argon2id.Verify() error = %v, wantErr %v", err, tt.wantErr) + return + } + if match != tt.wantMatch { + t.Errorf("Argon2id.Verify() match = %v, wantMatch %v", match, tt.wantMatch) + } + }) + } +} + +func TestArgon2ID_SaltUniqueness(t *testing.T) { + password := "testpassword" + iterations := 10 + + hashes := make(map[string]bool) + salts := make(map[string]bool) + + for i := 0; i < iterations; i++ { + hash, err := Argon2id.Hash(password) + if err != nil { + t.Fatalf("Hash iteration %d failed: %v", i, err) + } + + // Check hash uniqueness + if hashes[hash] { + t.Errorf("Duplicate hash generated at iteration %d", i) + } + hashes[hash] = true + + // Extract and check salt uniqueness + parts := strings.Split(hash, "$") + if len(parts) >= 5 { + salt := parts[4] + if salts[salt] { + t.Errorf("Duplicate salt generated at iteration %d", i) + } + salts[salt] = true + } + + // Verify each hash works + match, err := Argon2id.Verify(password, hash) + if err != nil || !match { + t.Errorf("Hash %d failed verification: err=%v, match=%v", i, err, match) + } + } +} + +func TestArgon2ID_HashFormat(t *testing.T) { + password := "testformat" + hash, err := Argon2id.Hash(password) + if err != nil { + t.Fatalf("Hash failed: %v", err) + } + + parts := strings.Split(hash, "$") + if len(parts) != 6 { + t.Fatalf("Expected 6 parts, got %d: %v", len(parts), hash) + } + + // Part 0 should be empty (before first $) + if parts[0] != "" { + t.Errorf("Part 0 should be empty, got: %v", parts[0]) + } + + // Part 1 should be "argon2id" + if parts[1] != "argon2id" { + t.Errorf("Part 1 should be 'argon2id', got: %v", parts[1]) + } + + // Part 2 should be version + if !strings.HasPrefix(parts[2], "v=") { + t.Errorf("Part 2 should start with 'v=', got: %v", parts[2]) + } + + // Part 3 should be parameters + if !strings.Contains(parts[3], "m=") || !strings.Contains(parts[3], "t=") || !strings.Contains(parts[3], "p=") { + t.Errorf("Part 3 should contain m=, t=, and p=, got: %v", parts[3]) + } + + // Part 4 should be base64 encoded salt + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + t.Errorf("Salt (part 4) is not valid base64: %v", err) + } + if len(salt) != int(Argon2id.saltLen) { + t.Errorf("Salt length is %d, expected %d", len(salt), Argon2id.saltLen) + } + + // Part 5 should be base64 encoded hash + decodedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + t.Errorf("Hash (part 5) is not valid base64: %v", err) + } + if len(decodedHash) != int(Argon2id.keyLen) { + t.Errorf("Hash length is %d, expected %d", len(decodedHash), Argon2id.keyLen) + } +} + +func TestArgon2ID_CaseModification(t *testing.T) { + // Passwords should be case-sensitive + password := "TestPassword" + hash, err := Argon2id.Hash(password) + if err != nil { + t.Fatalf("Hash failed: %v", err) + } + + // Correct case should match + match, err := Argon2id.Verify(password, hash) + if err != nil || !match { + t.Errorf("Correct password failed: err=%v, match=%v", err, match) + } + + // Wrong case should not match + match, err = Argon2id.Verify("testpassword", hash) + if err != nil { + t.Errorf("Verify returned error: %v", err) + } + if match { + t.Error("Password verification should be case-sensitive") + } + + match, err = Argon2id.Verify("TESTPASSWORD", hash) + if err != nil { + t.Errorf("Verify returned error: %v", err) + } + if match { + t.Error("Password verification should be case-sensitive") + } +} + +func TestArgon2ID_InvalidParameters(t *testing.T) { + password := "testpassword" + + tests := []struct { + name string + hash string + wantErr bool + }{ + { + name: "negative memory parameter", + hash: "$argon2id$v=19$m=-1,t=1,p=4$dGVzdHNhbHQ$testhash", + wantErr: true, + }, + { + name: "negative time parameter", + hash: "$argon2id$v=19$m=65536,t=-1,p=4$dGVzdHNhbHQ$testhash", + wantErr: true, + }, + { + name: "negative parallelism parameter", + hash: "$argon2id$v=19$m=65536,t=1,p=-4$dGVzdHNhbHQ$testhash", + wantErr: true, + }, + { + name: "zero memory parameter", + hash: "$argon2id$v=19$m=0,t=1,p=4$dGVzdHNhbHQ$testhash", + wantErr: false, // argon2 may handle this, we just test parsing + }, + { + name: "missing parameter value", + hash: "$argon2id$v=19$m=,t=1,p=4$dGVzdHNhbHQ$testhash", + wantErr: true, + }, + { + name: "non-numeric parameter", + hash: "$argon2id$v=19$m=abc,t=1,p=4$dGVzdHNhbHQ$testhash", + wantErr: true, + }, + { + name: "missing parameters separator", + hash: "$argon2id$v=19$m=65536 t=1 p=4$dGVzdHNhbHQ$testhash", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Argon2id.Verify(password, tt.hash) + if (err != nil) != tt.wantErr { + t.Errorf("Argon2id.Verify() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestArgon2ID_ConcurrentHashing(t *testing.T) { + password := "testpassword" + concurrency := 10 + + type result struct { + hash string + err error + } + + results := make(chan result, concurrency) + + // Generate hashes concurrently + for i := 0; i < concurrency; i++ { + go func() { + hash, err := Argon2id.Hash(password) + results <- result{hash: hash, err: err} + }() + } + + // Collect results + hashes := make(map[string]bool) + for i := 0; i < concurrency; i++ { + res := <-results + if res.err != nil { + t.Errorf("Concurrent hash %d failed: %v", i, res.err) + continue + } + + // Check for duplicates + if hashes[res.hash] { + t.Errorf("Duplicate hash generated in concurrent test") + } + hashes[res.hash] = true + + // Verify each hash works + match, err := Argon2id.Verify(password, res.hash) + if err != nil || !match { + t.Errorf("Hash %d failed verification: err=%v, match=%v", i, err, match) + } + } +} + +func TestArgon2ID_VeryLongPassword(t *testing.T) { + // Test with extremely long password (100KB) + password := strings.Repeat("a", 100*1024) + + hash, err := Argon2id.Hash(password) + if err != nil { + t.Fatalf("Failed to hash very long password: %v", err) + } + + match, err := Argon2id.Verify(password, hash) + if err != nil { + t.Fatalf("Failed to verify very long password: %v", err) + } + + if !match { + t.Error("Very long password failed verification") + } + + // Verify wrong password still fails + wrongPassword := strings.Repeat("b", 100*1024) + match, err = Argon2id.Verify(wrongPassword, hash) + if err != nil { + t.Errorf("Verify returned error: %v", err) + } + if match { + t.Error("Wrong very long password should not match") + } +} diff --git a/internal/auth/password/password_test.go b/internal/auth/password/password_test.go new file mode 100644 index 0000000..1fcb17a --- /dev/null +++ b/internal/auth/password/password_test.go @@ -0,0 +1,193 @@ +package password + +import ( + "strings" + "testing" +) + +func TestHashPassword(t *testing.T) { + tests := []struct { + name string + password string + wantErr bool + }{ + { + name: "simple password", + password: "password123", + wantErr: false, + }, + { + name: "empty password", + password: "", + wantErr: false, + }, + { + name: "long password", + password: strings.Repeat("a", 1000), + wantErr: false, + }, + { + name: "special characters", + password: "p@ssw0rd!#$%^&*()", + wantErr: false, + }, + { + name: "unicode characters", + password: "パスワード123", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := HashPassword(tt.password) + if (err != nil) != tt.wantErr { + t.Errorf("HashPassword() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + // Verify hash format + if !strings.HasPrefix(hash, "$argon2id$") { + t.Errorf("HashPassword() returned invalid hash format: %v", hash) + } + // Verify hash has correct number of parts + parts := strings.Split(hash, "$") + if len(parts) != 6 { + t.Errorf("HashPassword() returned hash with incorrect number of parts: %v", len(parts)) + } + } + }) + } +} + +func TestVerifyPassword(t *testing.T) { + // Pre-generate a known hash for testing + testPassword := "testpassword123" + testHash, err := HashPassword(testPassword) + if err != nil { + t.Fatalf("Failed to generate test hash: %v", err) + } + + tests := []struct { + name string + password string + hash string + wantMatch bool + wantErr bool + }{ + { + name: "correct password", + password: testPassword, + hash: testHash, + wantMatch: true, + wantErr: false, + }, + { + name: "incorrect password", + password: "wrongpassword", + hash: testHash, + wantMatch: false, + wantErr: false, + }, + { + name: "empty password against valid hash", + password: "", + hash: testHash, + wantMatch: false, + wantErr: false, + }, + { + name: "empty hash", + password: testPassword, + hash: "", + wantMatch: false, + wantErr: false, + }, + { + name: "invalid hash format", + password: testPassword, + hash: "invalid", + wantMatch: false, + wantErr: true, + }, + { + name: "malformed hash - wrong prefix", + password: testPassword, + hash: "$bcrypt$invalid$hash", + wantMatch: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + match, err := VerifyPassword(tt.password, tt.hash) + if (err != nil) != tt.wantErr { + t.Errorf("VerifyPassword() error = %v, wantErr %v", err, tt.wantErr) + return + } + if match != tt.wantMatch { + t.Errorf("VerifyPassword() match = %v, wantMatch %v", match, tt.wantMatch) + } + }) + } +} + +func TestHashPasswordUniqueness(t *testing.T) { + password := "testpassword" + + // Generate multiple hashes of the same password + hash1, err := HashPassword(password) + if err != nil { + t.Fatalf("Failed to hash password: %v", err) + } + + hash2, err := HashPassword(password) + if err != nil { + t.Fatalf("Failed to hash password: %v", err) + } + + // Hashes should be different due to different salts + if hash1 == hash2 { + t.Error("HashPassword() should generate unique hashes for the same password") + } + + // But both should verify correctly + match1, err := VerifyPassword(password, hash1) + if err != nil || !match1 { + t.Errorf("Failed to verify first hash: err=%v, match=%v", err, match1) + } + + match2, err := VerifyPassword(password, hash2) + if err != nil || !match2 { + t.Errorf("Failed to verify second hash: err=%v, match=%v", err, match2) + } +} + +func TestPasswordRoundTrip(t *testing.T) { + tests := []string{ + "simple", + "with spaces and special chars !@#$%", + "パスワード", + strings.Repeat("long", 100), + "", + } + + for _, password := range tests { + t.Run(password, func(t *testing.T) { + hash, err := HashPassword(password) + if err != nil { + t.Fatalf("HashPassword() failed: %v", err) + } + + match, err := VerifyPassword(password, hash) + if err != nil { + t.Fatalf("VerifyPassword() failed: %v", err) + } + + if !match { + t.Error("Password round trip failed: hashed password does not verify") + } + }) + } +} \ No newline at end of file diff --git a/internal/auth/totp/aes_test.go b/internal/auth/totp/aes_test.go new file mode 100644 index 0000000..03139b5 --- /dev/null +++ b/internal/auth/totp/aes_test.go @@ -0,0 +1,430 @@ +package totp + +import ( + "bytes" + "crypto/aes" + "testing" +) + +func TestAESEncrypt(t *testing.T) { + tests := []struct { + name string + key []byte + text []byte + wantErr bool + }{ + { + name: "basic encryption with 16-byte key", + key: []byte("1234567890123456"), // 16 bytes (AES-128) + text: []byte("hello world"), + wantErr: false, + }, + { + name: "basic encryption with 24-byte key", + key: []byte("123456789012345678901234"), // 24 bytes (AES-192) + text: []byte("hello world"), + wantErr: false, + }, + { + name: "basic encryption with 32-byte key", + key: []byte("12345678901234567890123456789012"), // 32 bytes (AES-256) + text: []byte("hello world"), + wantErr: false, + }, + { + name: "empty text", + key: []byte("1234567890123456"), + text: []byte(""), + wantErr: false, + }, + { + name: "long text", + key: []byte("1234567890123456"), + text: []byte("This is a much longer text that spans multiple blocks and should be encrypted properly without any issues"), + wantErr: false, + }, + { + name: "binary data", + key: []byte("1234567890123456"), + text: []byte{0x00, 0x01, 0x02, 0x03, 0xFF, 0xFE, 0xFD}, + wantErr: false, + }, + { + name: "invalid key length - too short", + key: []byte("short"), + text: []byte("hello world"), + wantErr: true, + }, + { + name: "invalid key length - 17 bytes", + key: []byte("12345678901234567"), // 17 bytes (invalid) + text: []byte("hello world"), + wantErr: true, + }, + { + name: "nil key", + key: nil, + text: []byte("hello world"), + wantErr: true, + }, + { + name: "empty key", + key: []byte(""), + text: []byte("hello world"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := AESEncrypt(tt.key, tt.text) + if (err != nil) != tt.wantErr { + t.Errorf("AESEncrypt() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // Verify ciphertext is not empty + if len(ciphertext) == 0 { + t.Error("AESEncrypt() returned empty ciphertext") + } + + // Verify ciphertext length is correct (IV + encrypted text) + expectedLen := aes.BlockSize + len(tt.text) + if len(ciphertext) != expectedLen { + t.Errorf("AESEncrypt() ciphertext length = %d, want %d", len(ciphertext), expectedLen) + } + + // Verify ciphertext is different from plaintext (unless text is empty) + if len(tt.text) > 0 && bytes.Equal(ciphertext[aes.BlockSize:], tt.text) { + t.Error("AESEncrypt() ciphertext matches plaintext") + } + + // Verify IV is present and non-zero + iv := ciphertext[:aes.BlockSize] + allZeros := true + for _, b := range iv { + if b != 0 { + allZeros = false + break + } + } + if allZeros { + t.Error("AESEncrypt() IV is all zeros") + } + } + }) + } +} + +func TestAESDecrypt(t *testing.T) { + validKey := []byte("1234567890123456") + validText := []byte("hello world") + + // Encrypt some data to use for valid test cases + validCiphertext, err := AESEncrypt(validKey, validText) + if err != nil { + t.Fatalf("Failed to create valid ciphertext: %v", err) + } + + tests := []struct { + name string + key []byte + ciphertext []byte + wantErr bool + }{ + { + name: "valid decryption", + key: validKey, + ciphertext: validCiphertext, + wantErr: false, + }, + { + name: "ciphertext too short - empty", + key: validKey, + ciphertext: []byte(""), + wantErr: true, + }, + { + name: "ciphertext too short - less than block size", + key: validKey, + ciphertext: []byte("short"), + wantErr: true, + }, + { + name: "ciphertext exactly block size (IV only, no data)", + key: validKey, + ciphertext: make([]byte, aes.BlockSize), + wantErr: false, + }, + { + name: "invalid key length", + key: []byte("short"), + ciphertext: validCiphertext, + wantErr: true, + }, + { + name: "wrong key", + key: []byte("6543210987654321"), + ciphertext: validCiphertext, + wantErr: false, // Decryption succeeds but produces garbage + }, + { + name: "nil key", + key: nil, + ciphertext: validCiphertext, + wantErr: true, + }, + { + name: "nil ciphertext", + key: validKey, + ciphertext: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + plaintext, err := AESDecrypt(tt.key, tt.ciphertext) + if (err != nil) != tt.wantErr { + t.Errorf("AESDecrypt() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // For valid decryption with correct key, verify we get original text + if tt.name == "valid decryption" && !bytes.Equal(plaintext, validText) { + t.Errorf("AESDecrypt() plaintext = %v, want %v", plaintext, validText) + } + + // For ciphertext with only IV, plaintext should be empty + if tt.name == "ciphertext exactly block size (IV only, no data)" && len(plaintext) != 0 { + t.Errorf("AESDecrypt() plaintext length = %d, want 0", len(plaintext)) + } + } + }) + } +} + +func TestAESEncryptDecrypt_RoundTrip(t *testing.T) { + tests := []struct { + name string + key []byte + text []byte + }{ + { + name: "basic round trip", + key: []byte("1234567890123456"), + text: []byte("hello world"), + }, + { + name: "empty text round trip", + key: []byte("1234567890123456"), + text: []byte(""), + }, + { + name: "long text round trip", + key: []byte("1234567890123456"), + text: []byte("This is a very long text that contains multiple blocks of data and should be encrypted and decrypted correctly without any data loss or corruption"), + }, + { + name: "binary data round trip", + key: []byte("1234567890123456"), + text: []byte{0x00, 0x01, 0x02, 0x03, 0xFF, 0xFE, 0xFD, 0xFC}, + }, + { + name: "unicode text round trip", + key: []byte("1234567890123456"), + text: []byte("Hello 世界! 🔐 Encryption"), + }, + { + name: "AES-192 round trip", + key: []byte("123456789012345678901234"), + text: []byte("testing AES-192"), + }, + { + name: "AES-256 round trip", + key: []byte("12345678901234567890123456789012"), + text: []byte("testing AES-256"), + }, + { + name: "special characters", + key: []byte("1234567890123456"), + text: []byte("!@#$%^&*()_+-=[]{}|;':\",./<>?"), + }, + { + name: "newlines and tabs", + key: []byte("1234567890123456"), + text: []byte("line1\nline2\tline3\r\nline4"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Encrypt + ciphertext, err := AESEncrypt(tt.key, tt.text) + if err != nil { + t.Fatalf("AESEncrypt() failed: %v", err) + } + + // Decrypt + plaintext, err := AESDecrypt(tt.key, ciphertext) + if err != nil { + t.Fatalf("AESDecrypt() failed: %v", err) + } + + // Verify plaintext matches original + if !bytes.Equal(plaintext, tt.text) { + t.Errorf("Round trip failed: got %v, want %v", plaintext, tt.text) + } + }) + } +} + +func TestAESEncrypt_Uniqueness(t *testing.T) { + key := []byte("1234567890123456") + text := []byte("hello world") + iterations := 10 + + ciphertexts := make(map[string]bool) + + for i := 0; i < iterations; i++ { + ciphertext, err := AESEncrypt(key, text) + if err != nil { + t.Fatalf("Iteration %d failed: %v", i, err) + } + + // Each encryption should produce different ciphertext (due to random IV) + ciphertextStr := string(ciphertext) + if ciphertexts[ciphertextStr] { + t.Errorf("Duplicate ciphertext generated at iteration %d", i) + } + ciphertexts[ciphertextStr] = true + + // But all should decrypt to the same plaintext + plaintext, err := AESDecrypt(key, ciphertext) + if err != nil { + t.Fatalf("Iteration %d decryption failed: %v", i, err) + } + if !bytes.Equal(plaintext, text) { + t.Errorf("Iteration %d: decrypted text doesn't match original", i) + } + } +} + +func TestAESEncrypt_IVUniqueness(t *testing.T) { + key := []byte("1234567890123456") + text := []byte("test data") + iterations := 20 + + ivs := make(map[string]bool) + + for i := 0; i < iterations; i++ { + ciphertext, err := AESEncrypt(key, text) + if err != nil { + t.Fatalf("Iteration %d failed: %v", i, err) + } + + // Extract IV (first block) + iv := ciphertext[:aes.BlockSize] + ivStr := string(iv) + + // Each IV should be unique + if ivs[ivStr] { + t.Errorf("Duplicate IV generated at iteration %d", i) + } + ivs[ivStr] = true + } +} + +func TestAESDecrypt_WrongKey(t *testing.T) { + originalKey := []byte("1234567890123456") + wrongKey := []byte("6543210987654321") + text := []byte("secret message") + + // Encrypt with original key + ciphertext, err := AESEncrypt(originalKey, text) + if err != nil { + t.Fatalf("AESEncrypt() failed: %v", err) + } + + // Decrypt with wrong key - should not error but produce wrong plaintext + plaintext, err := AESDecrypt(wrongKey, ciphertext) + if err != nil { + t.Fatalf("AESDecrypt() with wrong key failed: %v", err) + } + + // Plaintext should be different from original + if bytes.Equal(plaintext, text) { + t.Error("AESDecrypt() with wrong key produced correct plaintext") + } +} + +func TestAESDecrypt_CorruptedCiphertext(t *testing.T) { + key := []byte("1234567890123456") + text := []byte("hello world") + + // Encrypt + ciphertext, err := AESEncrypt(key, text) + if err != nil { + t.Fatalf("AESEncrypt() failed: %v", err) + } + + // Corrupt the ciphertext (flip a bit in the encrypted data, not the IV) + if len(ciphertext) > aes.BlockSize { + corruptedCiphertext := make([]byte, len(ciphertext)) + copy(corruptedCiphertext, ciphertext) + corruptedCiphertext[aes.BlockSize] ^= 0xFF + + // Decrypt corrupted ciphertext - should not error but produce wrong plaintext + plaintext, err := AESDecrypt(key, corruptedCiphertext) + if err != nil { + t.Fatalf("AESDecrypt() with corrupted ciphertext failed: %v", err) + } + + // Plaintext should be different from original + if bytes.Equal(plaintext, text) { + t.Error("AESDecrypt() with corrupted ciphertext produced correct plaintext") + } + } +} + +func TestAESEncryptDecrypt_DifferentKeySizes(t *testing.T) { + tests := []struct { + name string + keySize int + }{ + {"AES-128", 16}, + {"AES-192", 24}, + {"AES-256", 32}, + } + + text := []byte("test message for different key sizes") + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate key of specified size + key := make([]byte, tt.keySize) + for i := range key { + key[i] = byte(i) + } + + // Encrypt + ciphertext, err := AESEncrypt(key, text) + if err != nil { + t.Fatalf("AESEncrypt() failed: %v", err) + } + + // Decrypt + plaintext, err := AESDecrypt(key, ciphertext) + if err != nil { + t.Fatalf("AESDecrypt() failed: %v", err) + } + + // Verify + if !bytes.Equal(plaintext, text) { + t.Errorf("Round trip failed for %s", tt.name) + } + }) + } +} diff --git a/internal/auth/totp/totp_test.go b/internal/auth/totp/totp_test.go new file mode 100644 index 0000000..72f97c1 --- /dev/null +++ b/internal/auth/totp/totp_test.go @@ -0,0 +1,431 @@ +package totp + +import ( + "encoding/base64" + "strings" + "sync" + "testing" + "time" + + "github.com/pquerna/otp/totp" +) + +func TestGenerateQRCode(t *testing.T) { + tests := []struct { + name string + username string + siteUrl string + secret []byte + wantErr bool + }{ + { + name: "basic generation with nil secret", + username: "testuser", + siteUrl: "opengist.io", + secret: nil, + wantErr: false, + }, + { + name: "basic generation with provided secret", + username: "testuser", + siteUrl: "opengist.io", + secret: []byte("1234567890123456"), + wantErr: false, + }, + { + name: "username with special characters", + username: "test.user", + siteUrl: "opengist.io", + secret: nil, + wantErr: false, + }, + { + name: "site URL with protocol and port", + username: "testuser", + siteUrl: "https://opengist.io:6157", + secret: nil, + wantErr: false, + }, + { + name: "empty username", + username: "", + siteUrl: "opengist.io", + secret: nil, + wantErr: true, + }, + { + name: "empty site URL", + username: "testuser", + siteUrl: "", + secret: nil, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + secretStr, qrcode, secretBytes, err := GenerateQRCode(tt.username, tt.siteUrl, tt.secret) + if (err != nil) != tt.wantErr { + t.Errorf("GenerateQRCode() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // Verify secret string is not empty + if secretStr == "" { + t.Error("GenerateQRCode() returned empty secret string") + } + + // Verify QR code image is generated + if qrcode == "" { + t.Error("GenerateQRCode() returned empty QR code") + } + + // Verify QR code has correct data URI prefix + if !strings.HasPrefix(string(qrcode), "data:image/png;base64,") { + t.Errorf("QR code does not have correct data URI prefix: %s", qrcode[:50]) + } + + // Verify QR code is valid base64 after prefix + base64Data := strings.TrimPrefix(string(qrcode), "data:image/png;base64,") + _, err := base64.StdEncoding.DecodeString(base64Data) + if err != nil { + t.Errorf("QR code base64 data is invalid: %v", err) + } + + // Verify secret bytes are returned + if secretBytes == nil { + t.Error("GenerateQRCode() returned nil secret bytes") + } + + // Verify secret bytes have correct length + if len(secretBytes) != secretSize { + t.Errorf("Secret bytes length = %d, want %d", len(secretBytes), secretSize) + } + + // If a secret was provided, verify it matches what was returned + if tt.secret != nil && string(secretBytes) != string(tt.secret) { + t.Error("Returned secret bytes do not match provided secret") + } + } + }) + } +} + +func TestGenerateQRCode_SecretUniqueness(t *testing.T) { + username := "testuser" + siteUrl := "opengist.io" + iterations := 10 + + secrets := make(map[string]bool) + secretBytes := make(map[string]bool) + + for i := 0; i < iterations; i++ { + secretStr, _, secret, err := GenerateQRCode(username, siteUrl, nil) + if err != nil { + t.Fatalf("Iteration %d failed: %v", i, err) + } + + // Check secret string uniqueness + if secrets[secretStr] { + t.Errorf("Duplicate secret string generated at iteration %d", i) + } + secrets[secretStr] = true + + // Check secret bytes uniqueness + secretKey := string(secret) + if secretBytes[secretKey] { + t.Errorf("Duplicate secret bytes generated at iteration %d", i) + } + secretBytes[secretKey] = true + } +} + +func TestGenerateQRCode_WithProvidedSecret(t *testing.T) { + username := "testuser" + siteUrl := "opengist.io" + providedSecret := []byte("mysecret12345678") + + // Generate QR code multiple times with the same secret + secretStr1, _, secret1, err := GenerateQRCode(username, siteUrl, providedSecret) + if err != nil { + t.Fatalf("First generation failed: %v", err) + } + + secretStr2, _, secret2, err := GenerateQRCode(username, siteUrl, providedSecret) + if err != nil { + t.Fatalf("Second generation failed: %v", err) + } + + // Secret strings should be the same when using the same input secret + if secretStr1 != secretStr2 { + t.Error("Secret strings differ when using the same provided secret") + } + + // Secret bytes should match the provided secret + if string(secret1) != string(providedSecret) { + t.Error("Returned secret bytes do not match provided secret (first call)") + } + if string(secret2) != string(providedSecret) { + t.Error("Returned secret bytes do not match provided secret (second call)") + } +} + +func TestGenerateQRCode_ConcurrentGeneration(t *testing.T) { + username := "testuser" + siteUrl := "opengist.io" + concurrency := 10 + + type result struct { + secretStr string + secretBytes []byte + err error + } + + results := make(chan result, concurrency) + var wg sync.WaitGroup + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + secretStr, _, secretBytes, err := GenerateQRCode(username, siteUrl, nil) + results <- result{secretStr: secretStr, secretBytes: secretBytes, err: err} + }() + } + + wg.Wait() + close(results) + + secrets := make(map[string]bool) + for res := range results { + if res.err != nil { + t.Errorf("Concurrent generation failed: %v", res.err) + continue + } + + // Check for duplicates + if secrets[res.secretStr] { + t.Error("Duplicate secret generated in concurrent test") + } + secrets[res.secretStr] = true + } +} + +func TestValidate(t *testing.T) { + // Generate a valid secret for testing + _, _, secret, err := GenerateQRCode("testuser", "opengist.io", nil) + if err != nil { + t.Fatalf("Failed to generate secret: %v", err) + } + + // Convert secret bytes to base32 string for TOTP + secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", secret) + if err != nil { + t.Fatalf("Failed to generate secret string: %v", err) + } + + // Generate a valid passcode for the current time + validPasscode, err := totp.GenerateCode(secretStr, time.Now()) + if err != nil { + t.Fatalf("Failed to generate valid passcode: %v", err) + } + + tests := []struct { + name string + passcode string + secret string + wantValid bool + }{ + { + name: "valid passcode", + passcode: validPasscode, + secret: secretStr, + wantValid: true, + }, + { + name: "invalid passcode - wrong digits", + passcode: "000000", + secret: secretStr, + wantValid: false, + }, + { + name: "invalid passcode - wrong length", + passcode: "123", + secret: secretStr, + wantValid: false, + }, + { + name: "empty passcode", + passcode: "", + secret: secretStr, + wantValid: false, + }, + { + name: "empty secret", + passcode: validPasscode, + secret: "", + wantValid: false, + }, + { + name: "invalid secret format", + passcode: validPasscode, + secret: "not-a-valid-base32-secret!@#", + wantValid: false, + }, + { + name: "passcode with letters", + passcode: "12345A", + secret: secretStr, + wantValid: false, + }, + { + name: "passcode with spaces", + passcode: "123 456", + secret: secretStr, + wantValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + valid := Validate(tt.passcode, tt.secret) + if valid != tt.wantValid { + t.Errorf("Validate() = %v, want %v", valid, tt.wantValid) + } + }) + } +} + +func TestValidate_TimeDrift(t *testing.T) { + // Generate a valid secret + secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", nil) + if err != nil { + t.Fatalf("Failed to generate secret: %v", err) + } + + // Test that passcodes from previous and next time windows are accepted + // (TOTP typically accepts codes from ±1 time window for clock drift) + pastTime := time.Now().Add(-30 * time.Second) + futureTime := time.Now().Add(30 * time.Second) + + pastPasscode, err := totp.GenerateCode(secretStr, pastTime) + if err != nil { + t.Fatalf("Failed to generate past passcode: %v", err) + } + + futurePasscode, err := totp.GenerateCode(secretStr, futureTime) + if err != nil { + t.Fatalf("Failed to generate future passcode: %v", err) + } + + // These should be valid due to time drift tolerance + if !Validate(pastPasscode, secretStr) { + t.Error("Validate() rejected passcode from previous time window") + } + + if !Validate(futurePasscode, secretStr) { + t.Error("Validate() rejected passcode from next time window") + } +} + +func TestValidate_ExpiredPasscode(t *testing.T) { + // Generate a valid secret + secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", nil) + if err != nil { + t.Fatalf("Failed to generate secret: %v", err) + } + + // Generate a passcode from 2 minutes ago (should be expired) + oldTime := time.Now().Add(-2 * time.Minute) + oldPasscode, err := totp.GenerateCode(secretStr, oldTime) + if err != nil { + t.Fatalf("Failed to generate old passcode: %v", err) + } + + // This should be invalid + if Validate(oldPasscode, secretStr) { + t.Error("Validate() accepted expired passcode from 2 minutes ago") + } +} + +func TestValidate_RoundTrip(t *testing.T) { + // Test full round trip: generate secret, generate code, validate code + tests := []struct { + name string + username string + siteUrl string + }{ + { + name: "basic round trip", + username: "testuser", + siteUrl: "opengist.io", + }, + { + name: "round trip with dot in username", + username: "test.user", + siteUrl: "opengist.io", + }, + { + name: "round trip with hyphen in username", + username: "test-user", + siteUrl: "opengist.io", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate QR code and secret + secretStr, _, _, err := GenerateQRCode(tt.username, tt.siteUrl, nil) + if err != nil { + t.Fatalf("GenerateQRCode() failed: %v", err) + } + + // Generate a valid passcode + passcode, err := totp.GenerateCode(secretStr, time.Now()) + if err != nil { + t.Fatalf("GenerateCode() failed: %v", err) + } + + // Validate the passcode + if !Validate(passcode, secretStr) { + t.Error("Validate() rejected valid passcode") + } + + // Validate wrong passcode fails + wrongPasscode := "000000" + if passcode == wrongPasscode { + wrongPasscode = "111111" + } + if Validate(wrongPasscode, secretStr) { + t.Error("Validate() accepted invalid passcode") + } + }) + } +} + +func TestGenerateSecret(t *testing.T) { + // Test the internal generateSecret function behavior through GenerateQRCode + for i := 0; i < 10; i++ { + _, _, secret, err := GenerateQRCode("testuser", "opengist.io", nil) + if err != nil { + t.Fatalf("Iteration %d: generateSecret() failed: %v", i, err) + } + + if len(secret) != secretSize { + t.Errorf("Iteration %d: secret length = %d, want %d", i, len(secret), secretSize) + } + + // Verify secret is not all zeros (extremely unlikely with crypto/rand) + allZeros := true + for _, b := range secret { + if b != 0 { + allZeros = false + break + } + } + if allZeros { + t.Errorf("Iteration %d: secret is all zeros", i) + } + } +}