Add some tests (#553)
This commit is contained in:
427
internal/auth/password/argon2id_test.go
Normal file
427
internal/auth/password/argon2id_test.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package password
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestArgon2ID_Hash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
plain string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic password",
|
||||
plain: "password123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
plain: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "long password",
|
||||
plain: strings.Repeat("a", 10000),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unicode password",
|
||||
plain: "パスワード🔒",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
plain: "!@#$%^&*()_+-=[]{}|;:',.<>?/`~",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash, err := Argon2id.Hash(tt.plain)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Argon2id.Hash() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
// Verify the hash format
|
||||
if !strings.HasPrefix(hash, "$argon2id$") {
|
||||
t.Errorf("Hash does not start with $argon2id$: %v", hash)
|
||||
}
|
||||
|
||||
// Verify all parts are present
|
||||
parts := strings.Split(hash, "$")
|
||||
if len(parts) != 6 {
|
||||
t.Errorf("Hash has %d parts, expected 6: %v", len(parts), hash)
|
||||
}
|
||||
|
||||
// Verify salt is properly encoded
|
||||
if len(parts) >= 5 {
|
||||
_, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
t.Errorf("Salt is not properly base64 encoded: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify hash is properly encoded
|
||||
if len(parts) >= 6 {
|
||||
_, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
t.Errorf("Hash is not properly base64 encoded: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2ID_Verify(t *testing.T) {
|
||||
// Generate a valid hash for testing
|
||||
testPassword := "correctpassword"
|
||||
validHash, err := Argon2id.Hash(testPassword)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test hash: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
plain string
|
||||
hash string
|
||||
wantMatch bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "correct password",
|
||||
plain: testPassword,
|
||||
hash: validHash,
|
||||
wantMatch: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "incorrect password",
|
||||
plain: "wrongpassword",
|
||||
hash: validHash,
|
||||
wantMatch: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
plain: "",
|
||||
hash: validHash,
|
||||
wantMatch: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty hash",
|
||||
plain: testPassword,
|
||||
hash: "",
|
||||
wantMatch: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid hash - too few parts",
|
||||
plain: testPassword,
|
||||
hash: "$argon2id$v=19$m=65536",
|
||||
wantMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid hash - too many parts",
|
||||
plain: testPassword,
|
||||
hash: "$argon2id$v=19$m=65536,t=1,p=4$salt$hash$extra",
|
||||
wantMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid hash - malformed parameters",
|
||||
plain: testPassword,
|
||||
hash: "$argon2id$v=19$invalid$salt$hash",
|
||||
wantMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid hash - bad base64 salt",
|
||||
plain: testPassword,
|
||||
hash: "$argon2id$v=19$m=65536,t=1,p=4$not-valid-base64!@#$hash",
|
||||
wantMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid hash - bad base64 hash",
|
||||
plain: testPassword,
|
||||
hash: "$argon2id$v=19$m=65536,t=1,p=4$dGVzdA$not-valid-base64!@#",
|
||||
wantMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong algorithm prefix",
|
||||
plain: testPassword,
|
||||
hash: "$bcrypt$rounds=10$saltsaltsaltsaltsalt",
|
||||
wantMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
match, err := Argon2id.Verify(tt.plain, tt.hash)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Argon2id.Verify() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if match != tt.wantMatch {
|
||||
t.Errorf("Argon2id.Verify() match = %v, wantMatch %v", match, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2ID_SaltUniqueness(t *testing.T) {
|
||||
password := "testpassword"
|
||||
iterations := 10
|
||||
|
||||
hashes := make(map[string]bool)
|
||||
salts := make(map[string]bool)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
hash, err := Argon2id.Hash(password)
|
||||
if err != nil {
|
||||
t.Fatalf("Hash iteration %d failed: %v", i, err)
|
||||
}
|
||||
|
||||
// Check hash uniqueness
|
||||
if hashes[hash] {
|
||||
t.Errorf("Duplicate hash generated at iteration %d", i)
|
||||
}
|
||||
hashes[hash] = true
|
||||
|
||||
// Extract and check salt uniqueness
|
||||
parts := strings.Split(hash, "$")
|
||||
if len(parts) >= 5 {
|
||||
salt := parts[4]
|
||||
if salts[salt] {
|
||||
t.Errorf("Duplicate salt generated at iteration %d", i)
|
||||
}
|
||||
salts[salt] = true
|
||||
}
|
||||
|
||||
// Verify each hash works
|
||||
match, err := Argon2id.Verify(password, hash)
|
||||
if err != nil || !match {
|
||||
t.Errorf("Hash %d failed verification: err=%v, match=%v", i, err, match)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2ID_HashFormat(t *testing.T) {
|
||||
password := "testformat"
|
||||
hash, err := Argon2id.Hash(password)
|
||||
if err != nil {
|
||||
t.Fatalf("Hash failed: %v", err)
|
||||
}
|
||||
|
||||
parts := strings.Split(hash, "$")
|
||||
if len(parts) != 6 {
|
||||
t.Fatalf("Expected 6 parts, got %d: %v", len(parts), hash)
|
||||
}
|
||||
|
||||
// Part 0 should be empty (before first $)
|
||||
if parts[0] != "" {
|
||||
t.Errorf("Part 0 should be empty, got: %v", parts[0])
|
||||
}
|
||||
|
||||
// Part 1 should be "argon2id"
|
||||
if parts[1] != "argon2id" {
|
||||
t.Errorf("Part 1 should be 'argon2id', got: %v", parts[1])
|
||||
}
|
||||
|
||||
// Part 2 should be version
|
||||
if !strings.HasPrefix(parts[2], "v=") {
|
||||
t.Errorf("Part 2 should start with 'v=', got: %v", parts[2])
|
||||
}
|
||||
|
||||
// Part 3 should be parameters
|
||||
if !strings.Contains(parts[3], "m=") || !strings.Contains(parts[3], "t=") || !strings.Contains(parts[3], "p=") {
|
||||
t.Errorf("Part 3 should contain m=, t=, and p=, got: %v", parts[3])
|
||||
}
|
||||
|
||||
// Part 4 should be base64 encoded salt
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
t.Errorf("Salt (part 4) is not valid base64: %v", err)
|
||||
}
|
||||
if len(salt) != int(Argon2id.saltLen) {
|
||||
t.Errorf("Salt length is %d, expected %d", len(salt), Argon2id.saltLen)
|
||||
}
|
||||
|
||||
// Part 5 should be base64 encoded hash
|
||||
decodedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
t.Errorf("Hash (part 5) is not valid base64: %v", err)
|
||||
}
|
||||
if len(decodedHash) != int(Argon2id.keyLen) {
|
||||
t.Errorf("Hash length is %d, expected %d", len(decodedHash), Argon2id.keyLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2ID_CaseModification(t *testing.T) {
|
||||
// Passwords should be case-sensitive
|
||||
password := "TestPassword"
|
||||
hash, err := Argon2id.Hash(password)
|
||||
if err != nil {
|
||||
t.Fatalf("Hash failed: %v", err)
|
||||
}
|
||||
|
||||
// Correct case should match
|
||||
match, err := Argon2id.Verify(password, hash)
|
||||
if err != nil || !match {
|
||||
t.Errorf("Correct password failed: err=%v, match=%v", err, match)
|
||||
}
|
||||
|
||||
// Wrong case should not match
|
||||
match, err = Argon2id.Verify("testpassword", hash)
|
||||
if err != nil {
|
||||
t.Errorf("Verify returned error: %v", err)
|
||||
}
|
||||
if match {
|
||||
t.Error("Password verification should be case-sensitive")
|
||||
}
|
||||
|
||||
match, err = Argon2id.Verify("TESTPASSWORD", hash)
|
||||
if err != nil {
|
||||
t.Errorf("Verify returned error: %v", err)
|
||||
}
|
||||
if match {
|
||||
t.Error("Password verification should be case-sensitive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2ID_InvalidParameters(t *testing.T) {
|
||||
password := "testpassword"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hash string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "negative memory parameter",
|
||||
hash: "$argon2id$v=19$m=-1,t=1,p=4$dGVzdHNhbHQ$testhash",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative time parameter",
|
||||
hash: "$argon2id$v=19$m=65536,t=-1,p=4$dGVzdHNhbHQ$testhash",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative parallelism parameter",
|
||||
hash: "$argon2id$v=19$m=65536,t=1,p=-4$dGVzdHNhbHQ$testhash",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero memory parameter",
|
||||
hash: "$argon2id$v=19$m=0,t=1,p=4$dGVzdHNhbHQ$testhash",
|
||||
wantErr: false, // argon2 may handle this, we just test parsing
|
||||
},
|
||||
{
|
||||
name: "missing parameter value",
|
||||
hash: "$argon2id$v=19$m=,t=1,p=4$dGVzdHNhbHQ$testhash",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-numeric parameter",
|
||||
hash: "$argon2id$v=19$m=abc,t=1,p=4$dGVzdHNhbHQ$testhash",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing parameters separator",
|
||||
hash: "$argon2id$v=19$m=65536 t=1 p=4$dGVzdHNhbHQ$testhash",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := Argon2id.Verify(password, tt.hash)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Argon2id.Verify() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2ID_ConcurrentHashing(t *testing.T) {
|
||||
password := "testpassword"
|
||||
concurrency := 10
|
||||
|
||||
type result struct {
|
||||
hash string
|
||||
err error
|
||||
}
|
||||
|
||||
results := make(chan result, concurrency)
|
||||
|
||||
// Generate hashes concurrently
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
hash, err := Argon2id.Hash(password)
|
||||
results <- result{hash: hash, err: err}
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect results
|
||||
hashes := make(map[string]bool)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
res := <-results
|
||||
if res.err != nil {
|
||||
t.Errorf("Concurrent hash %d failed: %v", i, res.err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
if hashes[res.hash] {
|
||||
t.Errorf("Duplicate hash generated in concurrent test")
|
||||
}
|
||||
hashes[res.hash] = true
|
||||
|
||||
// Verify each hash works
|
||||
match, err := Argon2id.Verify(password, res.hash)
|
||||
if err != nil || !match {
|
||||
t.Errorf("Hash %d failed verification: err=%v, match=%v", i, err, match)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2ID_VeryLongPassword(t *testing.T) {
|
||||
// Test with extremely long password (100KB)
|
||||
password := strings.Repeat("a", 100*1024)
|
||||
|
||||
hash, err := Argon2id.Hash(password)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash very long password: %v", err)
|
||||
}
|
||||
|
||||
match, err := Argon2id.Verify(password, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to verify very long password: %v", err)
|
||||
}
|
||||
|
||||
if !match {
|
||||
t.Error("Very long password failed verification")
|
||||
}
|
||||
|
||||
// Verify wrong password still fails
|
||||
wrongPassword := strings.Repeat("b", 100*1024)
|
||||
match, err = Argon2id.Verify(wrongPassword, hash)
|
||||
if err != nil {
|
||||
t.Errorf("Verify returned error: %v", err)
|
||||
}
|
||||
if match {
|
||||
t.Error("Wrong very long password should not match")
|
||||
}
|
||||
}
|
||||
193
internal/auth/password/password_test.go
Normal file
193
internal/auth/password/password_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package password
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple password",
|
||||
password: "password123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
password: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "long password",
|
||||
password: strings.Repeat("a", 1000),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
password: "p@ssw0rd!#$%^&*()",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
password: "パスワード123",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash, err := HashPassword(tt.password)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("HashPassword() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr {
|
||||
// Verify hash format
|
||||
if !strings.HasPrefix(hash, "$argon2id$") {
|
||||
t.Errorf("HashPassword() returned invalid hash format: %v", hash)
|
||||
}
|
||||
// Verify hash has correct number of parts
|
||||
parts := strings.Split(hash, "$")
|
||||
if len(parts) != 6 {
|
||||
t.Errorf("HashPassword() returned hash with incorrect number of parts: %v", len(parts))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPassword(t *testing.T) {
|
||||
// Pre-generate a known hash for testing
|
||||
testPassword := "testpassword123"
|
||||
testHash, err := HashPassword(testPassword)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test hash: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
hash string
|
||||
wantMatch bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "correct password",
|
||||
password: testPassword,
|
||||
hash: testHash,
|
||||
wantMatch: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "incorrect password",
|
||||
password: "wrongpassword",
|
||||
hash: testHash,
|
||||
wantMatch: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty password against valid hash",
|
||||
password: "",
|
||||
hash: testHash,
|
||||
wantMatch: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty hash",
|
||||
password: testPassword,
|
||||
hash: "",
|
||||
wantMatch: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid hash format",
|
||||
password: testPassword,
|
||||
hash: "invalid",
|
||||
wantMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed hash - wrong prefix",
|
||||
password: testPassword,
|
||||
hash: "$bcrypt$invalid$hash",
|
||||
wantMatch: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
match, err := VerifyPassword(tt.password, tt.hash)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("VerifyPassword() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if match != tt.wantMatch {
|
||||
t.Errorf("VerifyPassword() match = %v, wantMatch %v", match, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPasswordUniqueness(t *testing.T) {
|
||||
password := "testpassword"
|
||||
|
||||
// Generate multiple hashes of the same password
|
||||
hash1, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
hash2, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
// Hashes should be different due to different salts
|
||||
if hash1 == hash2 {
|
||||
t.Error("HashPassword() should generate unique hashes for the same password")
|
||||
}
|
||||
|
||||
// But both should verify correctly
|
||||
match1, err := VerifyPassword(password, hash1)
|
||||
if err != nil || !match1 {
|
||||
t.Errorf("Failed to verify first hash: err=%v, match=%v", err, match1)
|
||||
}
|
||||
|
||||
match2, err := VerifyPassword(password, hash2)
|
||||
if err != nil || !match2 {
|
||||
t.Errorf("Failed to verify second hash: err=%v, match=%v", err, match2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordRoundTrip(t *testing.T) {
|
||||
tests := []string{
|
||||
"simple",
|
||||
"with spaces and special chars !@#$%",
|
||||
"パスワード",
|
||||
strings.Repeat("long", 100),
|
||||
"",
|
||||
}
|
||||
|
||||
for _, password := range tests {
|
||||
t.Run(password, func(t *testing.T) {
|
||||
hash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() failed: %v", err)
|
||||
}
|
||||
|
||||
match, err := VerifyPassword(password, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() failed: %v", err)
|
||||
}
|
||||
|
||||
if !match {
|
||||
t.Error("Password round trip failed: hashed password does not verify")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
430
internal/auth/totp/aes_test.go
Normal file
430
internal/auth/totp/aes_test.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package totp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAESEncrypt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key []byte
|
||||
text []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic encryption with 16-byte key",
|
||||
key: []byte("1234567890123456"), // 16 bytes (AES-128)
|
||||
text: []byte("hello world"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "basic encryption with 24-byte key",
|
||||
key: []byte("123456789012345678901234"), // 24 bytes (AES-192)
|
||||
text: []byte("hello world"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "basic encryption with 32-byte key",
|
||||
key: []byte("12345678901234567890123456789012"), // 32 bytes (AES-256)
|
||||
text: []byte("hello world"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty text",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte(""),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "long text",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte("This is a much longer text that spans multiple blocks and should be encrypted properly without any issues"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "binary data",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte{0x00, 0x01, 0x02, 0x03, 0xFF, 0xFE, 0xFD},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid key length - too short",
|
||||
key: []byte("short"),
|
||||
text: []byte("hello world"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid key length - 17 bytes",
|
||||
key: []byte("12345678901234567"), // 17 bytes (invalid)
|
||||
text: []byte("hello world"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil key",
|
||||
key: nil,
|
||||
text: []byte("hello world"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty key",
|
||||
key: []byte(""),
|
||||
text: []byte("hello world"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ciphertext, err := AESEncrypt(tt.key, tt.text)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AESEncrypt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
// Verify ciphertext is not empty
|
||||
if len(ciphertext) == 0 {
|
||||
t.Error("AESEncrypt() returned empty ciphertext")
|
||||
}
|
||||
|
||||
// Verify ciphertext length is correct (IV + encrypted text)
|
||||
expectedLen := aes.BlockSize + len(tt.text)
|
||||
if len(ciphertext) != expectedLen {
|
||||
t.Errorf("AESEncrypt() ciphertext length = %d, want %d", len(ciphertext), expectedLen)
|
||||
}
|
||||
|
||||
// Verify ciphertext is different from plaintext (unless text is empty)
|
||||
if len(tt.text) > 0 && bytes.Equal(ciphertext[aes.BlockSize:], tt.text) {
|
||||
t.Error("AESEncrypt() ciphertext matches plaintext")
|
||||
}
|
||||
|
||||
// Verify IV is present and non-zero
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
allZeros := true
|
||||
for _, b := range iv {
|
||||
if b != 0 {
|
||||
allZeros = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allZeros {
|
||||
t.Error("AESEncrypt() IV is all zeros")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESDecrypt(t *testing.T) {
|
||||
validKey := []byte("1234567890123456")
|
||||
validText := []byte("hello world")
|
||||
|
||||
// Encrypt some data to use for valid test cases
|
||||
validCiphertext, err := AESEncrypt(validKey, validText)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create valid ciphertext: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key []byte
|
||||
ciphertext []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid decryption",
|
||||
key: validKey,
|
||||
ciphertext: validCiphertext,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ciphertext too short - empty",
|
||||
key: validKey,
|
||||
ciphertext: []byte(""),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ciphertext too short - less than block size",
|
||||
key: validKey,
|
||||
ciphertext: []byte("short"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ciphertext exactly block size (IV only, no data)",
|
||||
key: validKey,
|
||||
ciphertext: make([]byte, aes.BlockSize),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid key length",
|
||||
key: []byte("short"),
|
||||
ciphertext: validCiphertext,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong key",
|
||||
key: []byte("6543210987654321"),
|
||||
ciphertext: validCiphertext,
|
||||
wantErr: false, // Decryption succeeds but produces garbage
|
||||
},
|
||||
{
|
||||
name: "nil key",
|
||||
key: nil,
|
||||
ciphertext: validCiphertext,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil ciphertext",
|
||||
key: validKey,
|
||||
ciphertext: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
plaintext, err := AESDecrypt(tt.key, tt.ciphertext)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AESDecrypt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
// For valid decryption with correct key, verify we get original text
|
||||
if tt.name == "valid decryption" && !bytes.Equal(plaintext, validText) {
|
||||
t.Errorf("AESDecrypt() plaintext = %v, want %v", plaintext, validText)
|
||||
}
|
||||
|
||||
// For ciphertext with only IV, plaintext should be empty
|
||||
if tt.name == "ciphertext exactly block size (IV only, no data)" && len(plaintext) != 0 {
|
||||
t.Errorf("AESDecrypt() plaintext length = %d, want 0", len(plaintext))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESEncryptDecrypt_RoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key []byte
|
||||
text []byte
|
||||
}{
|
||||
{
|
||||
name: "basic round trip",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte("hello world"),
|
||||
},
|
||||
{
|
||||
name: "empty text round trip",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte(""),
|
||||
},
|
||||
{
|
||||
name: "long text round trip",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte("This is a very long text that contains multiple blocks of data and should be encrypted and decrypted correctly without any data loss or corruption"),
|
||||
},
|
||||
{
|
||||
name: "binary data round trip",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte{0x00, 0x01, 0x02, 0x03, 0xFF, 0xFE, 0xFD, 0xFC},
|
||||
},
|
||||
{
|
||||
name: "unicode text round trip",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte("Hello 世界! 🔐 Encryption"),
|
||||
},
|
||||
{
|
||||
name: "AES-192 round trip",
|
||||
key: []byte("123456789012345678901234"),
|
||||
text: []byte("testing AES-192"),
|
||||
},
|
||||
{
|
||||
name: "AES-256 round trip",
|
||||
key: []byte("12345678901234567890123456789012"),
|
||||
text: []byte("testing AES-256"),
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte("!@#$%^&*()_+-=[]{}|;':\",./<>?"),
|
||||
},
|
||||
{
|
||||
name: "newlines and tabs",
|
||||
key: []byte("1234567890123456"),
|
||||
text: []byte("line1\nline2\tline3\r\nline4"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Encrypt
|
||||
ciphertext, err := AESEncrypt(tt.key, tt.text)
|
||||
if err != nil {
|
||||
t.Fatalf("AESEncrypt() failed: %v", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := AESDecrypt(tt.key, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("AESDecrypt() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify plaintext matches original
|
||||
if !bytes.Equal(plaintext, tt.text) {
|
||||
t.Errorf("Round trip failed: got %v, want %v", plaintext, tt.text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESEncrypt_Uniqueness(t *testing.T) {
|
||||
key := []byte("1234567890123456")
|
||||
text := []byte("hello world")
|
||||
iterations := 10
|
||||
|
||||
ciphertexts := make(map[string]bool)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
ciphertext, err := AESEncrypt(key, text)
|
||||
if err != nil {
|
||||
t.Fatalf("Iteration %d failed: %v", i, err)
|
||||
}
|
||||
|
||||
// Each encryption should produce different ciphertext (due to random IV)
|
||||
ciphertextStr := string(ciphertext)
|
||||
if ciphertexts[ciphertextStr] {
|
||||
t.Errorf("Duplicate ciphertext generated at iteration %d", i)
|
||||
}
|
||||
ciphertexts[ciphertextStr] = true
|
||||
|
||||
// But all should decrypt to the same plaintext
|
||||
plaintext, err := AESDecrypt(key, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("Iteration %d decryption failed: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, text) {
|
||||
t.Errorf("Iteration %d: decrypted text doesn't match original", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESEncrypt_IVUniqueness(t *testing.T) {
|
||||
key := []byte("1234567890123456")
|
||||
text := []byte("test data")
|
||||
iterations := 20
|
||||
|
||||
ivs := make(map[string]bool)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
ciphertext, err := AESEncrypt(key, text)
|
||||
if err != nil {
|
||||
t.Fatalf("Iteration %d failed: %v", i, err)
|
||||
}
|
||||
|
||||
// Extract IV (first block)
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
ivStr := string(iv)
|
||||
|
||||
// Each IV should be unique
|
||||
if ivs[ivStr] {
|
||||
t.Errorf("Duplicate IV generated at iteration %d", i)
|
||||
}
|
||||
ivs[ivStr] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESDecrypt_WrongKey(t *testing.T) {
|
||||
originalKey := []byte("1234567890123456")
|
||||
wrongKey := []byte("6543210987654321")
|
||||
text := []byte("secret message")
|
||||
|
||||
// Encrypt with original key
|
||||
ciphertext, err := AESEncrypt(originalKey, text)
|
||||
if err != nil {
|
||||
t.Fatalf("AESEncrypt() failed: %v", err)
|
||||
}
|
||||
|
||||
// Decrypt with wrong key - should not error but produce wrong plaintext
|
||||
plaintext, err := AESDecrypt(wrongKey, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("AESDecrypt() with wrong key failed: %v", err)
|
||||
}
|
||||
|
||||
// Plaintext should be different from original
|
||||
if bytes.Equal(plaintext, text) {
|
||||
t.Error("AESDecrypt() with wrong key produced correct plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESDecrypt_CorruptedCiphertext(t *testing.T) {
|
||||
key := []byte("1234567890123456")
|
||||
text := []byte("hello world")
|
||||
|
||||
// Encrypt
|
||||
ciphertext, err := AESEncrypt(key, text)
|
||||
if err != nil {
|
||||
t.Fatalf("AESEncrypt() failed: %v", err)
|
||||
}
|
||||
|
||||
// Corrupt the ciphertext (flip a bit in the encrypted data, not the IV)
|
||||
if len(ciphertext) > aes.BlockSize {
|
||||
corruptedCiphertext := make([]byte, len(ciphertext))
|
||||
copy(corruptedCiphertext, ciphertext)
|
||||
corruptedCiphertext[aes.BlockSize] ^= 0xFF
|
||||
|
||||
// Decrypt corrupted ciphertext - should not error but produce wrong plaintext
|
||||
plaintext, err := AESDecrypt(key, corruptedCiphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("AESDecrypt() with corrupted ciphertext failed: %v", err)
|
||||
}
|
||||
|
||||
// Plaintext should be different from original
|
||||
if bytes.Equal(plaintext, text) {
|
||||
t.Error("AESDecrypt() with corrupted ciphertext produced correct plaintext")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESEncryptDecrypt_DifferentKeySizes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keySize int
|
||||
}{
|
||||
{"AES-128", 16},
|
||||
{"AES-192", 24},
|
||||
{"AES-256", 32},
|
||||
}
|
||||
|
||||
text := []byte("test message for different key sizes")
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Generate key of specified size
|
||||
key := make([]byte, tt.keySize)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
ciphertext, err := AESEncrypt(key, text)
|
||||
if err != nil {
|
||||
t.Fatalf("AESEncrypt() failed: %v", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := AESDecrypt(key, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("AESDecrypt() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify
|
||||
if !bytes.Equal(plaintext, text) {
|
||||
t.Errorf("Round trip failed for %s", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
431
internal/auth/totp/totp_test.go
Normal file
431
internal/auth/totp/totp_test.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package totp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
func TestGenerateQRCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
siteUrl string
|
||||
secret []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic generation with nil secret",
|
||||
username: "testuser",
|
||||
siteUrl: "opengist.io",
|
||||
secret: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "basic generation with provided secret",
|
||||
username: "testuser",
|
||||
siteUrl: "opengist.io",
|
||||
secret: []byte("1234567890123456"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "username with special characters",
|
||||
username: "test.user",
|
||||
siteUrl: "opengist.io",
|
||||
secret: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "site URL with protocol and port",
|
||||
username: "testuser",
|
||||
siteUrl: "https://opengist.io:6157",
|
||||
secret: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty username",
|
||||
username: "",
|
||||
siteUrl: "opengist.io",
|
||||
secret: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty site URL",
|
||||
username: "testuser",
|
||||
siteUrl: "",
|
||||
secret: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
secretStr, qrcode, secretBytes, err := GenerateQRCode(tt.username, tt.siteUrl, tt.secret)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateQRCode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
// Verify secret string is not empty
|
||||
if secretStr == "" {
|
||||
t.Error("GenerateQRCode() returned empty secret string")
|
||||
}
|
||||
|
||||
// Verify QR code image is generated
|
||||
if qrcode == "" {
|
||||
t.Error("GenerateQRCode() returned empty QR code")
|
||||
}
|
||||
|
||||
// Verify QR code has correct data URI prefix
|
||||
if !strings.HasPrefix(string(qrcode), "data:image/png;base64,") {
|
||||
t.Errorf("QR code does not have correct data URI prefix: %s", qrcode[:50])
|
||||
}
|
||||
|
||||
// Verify QR code is valid base64 after prefix
|
||||
base64Data := strings.TrimPrefix(string(qrcode), "data:image/png;base64,")
|
||||
_, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
t.Errorf("QR code base64 data is invalid: %v", err)
|
||||
}
|
||||
|
||||
// Verify secret bytes are returned
|
||||
if secretBytes == nil {
|
||||
t.Error("GenerateQRCode() returned nil secret bytes")
|
||||
}
|
||||
|
||||
// Verify secret bytes have correct length
|
||||
if len(secretBytes) != secretSize {
|
||||
t.Errorf("Secret bytes length = %d, want %d", len(secretBytes), secretSize)
|
||||
}
|
||||
|
||||
// If a secret was provided, verify it matches what was returned
|
||||
if tt.secret != nil && string(secretBytes) != string(tt.secret) {
|
||||
t.Error("Returned secret bytes do not match provided secret")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateQRCode_SecretUniqueness(t *testing.T) {
|
||||
username := "testuser"
|
||||
siteUrl := "opengist.io"
|
||||
iterations := 10
|
||||
|
||||
secrets := make(map[string]bool)
|
||||
secretBytes := make(map[string]bool)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
secretStr, _, secret, err := GenerateQRCode(username, siteUrl, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Iteration %d failed: %v", i, err)
|
||||
}
|
||||
|
||||
// Check secret string uniqueness
|
||||
if secrets[secretStr] {
|
||||
t.Errorf("Duplicate secret string generated at iteration %d", i)
|
||||
}
|
||||
secrets[secretStr] = true
|
||||
|
||||
// Check secret bytes uniqueness
|
||||
secretKey := string(secret)
|
||||
if secretBytes[secretKey] {
|
||||
t.Errorf("Duplicate secret bytes generated at iteration %d", i)
|
||||
}
|
||||
secretBytes[secretKey] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateQRCode_WithProvidedSecret(t *testing.T) {
|
||||
username := "testuser"
|
||||
siteUrl := "opengist.io"
|
||||
providedSecret := []byte("mysecret12345678")
|
||||
|
||||
// Generate QR code multiple times with the same secret
|
||||
secretStr1, _, secret1, err := GenerateQRCode(username, siteUrl, providedSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("First generation failed: %v", err)
|
||||
}
|
||||
|
||||
secretStr2, _, secret2, err := GenerateQRCode(username, siteUrl, providedSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("Second generation failed: %v", err)
|
||||
}
|
||||
|
||||
// Secret strings should be the same when using the same input secret
|
||||
if secretStr1 != secretStr2 {
|
||||
t.Error("Secret strings differ when using the same provided secret")
|
||||
}
|
||||
|
||||
// Secret bytes should match the provided secret
|
||||
if string(secret1) != string(providedSecret) {
|
||||
t.Error("Returned secret bytes do not match provided secret (first call)")
|
||||
}
|
||||
if string(secret2) != string(providedSecret) {
|
||||
t.Error("Returned secret bytes do not match provided secret (second call)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateQRCode_ConcurrentGeneration(t *testing.T) {
|
||||
username := "testuser"
|
||||
siteUrl := "opengist.io"
|
||||
concurrency := 10
|
||||
|
||||
type result struct {
|
||||
secretStr string
|
||||
secretBytes []byte
|
||||
err error
|
||||
}
|
||||
|
||||
results := make(chan result, concurrency)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
secretStr, _, secretBytes, err := GenerateQRCode(username, siteUrl, nil)
|
||||
results <- result{secretStr: secretStr, secretBytes: secretBytes, err: err}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
secrets := make(map[string]bool)
|
||||
for res := range results {
|
||||
if res.err != nil {
|
||||
t.Errorf("Concurrent generation failed: %v", res.err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
if secrets[res.secretStr] {
|
||||
t.Error("Duplicate secret generated in concurrent test")
|
||||
}
|
||||
secrets[res.secretStr] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
// Generate a valid secret for testing
|
||||
_, _, secret, err := GenerateQRCode("testuser", "opengist.io", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate secret: %v", err)
|
||||
}
|
||||
|
||||
// Convert secret bytes to base32 string for TOTP
|
||||
secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate secret string: %v", err)
|
||||
}
|
||||
|
||||
// Generate a valid passcode for the current time
|
||||
validPasscode, err := totp.GenerateCode(secretStr, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate valid passcode: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
passcode string
|
||||
secret string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid passcode",
|
||||
passcode: validPasscode,
|
||||
secret: secretStr,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "invalid passcode - wrong digits",
|
||||
passcode: "000000",
|
||||
secret: secretStr,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid passcode - wrong length",
|
||||
passcode: "123",
|
||||
secret: secretStr,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "empty passcode",
|
||||
passcode: "",
|
||||
secret: secretStr,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "empty secret",
|
||||
passcode: validPasscode,
|
||||
secret: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid secret format",
|
||||
passcode: validPasscode,
|
||||
secret: "not-a-valid-base32-secret!@#",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "passcode with letters",
|
||||
passcode: "12345A",
|
||||
secret: secretStr,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "passcode with spaces",
|
||||
passcode: "123 456",
|
||||
secret: secretStr,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
valid := Validate(tt.passcode, tt.secret)
|
||||
if valid != tt.wantValid {
|
||||
t.Errorf("Validate() = %v, want %v", valid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_TimeDrift(t *testing.T) {
|
||||
// Generate a valid secret
|
||||
secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate secret: %v", err)
|
||||
}
|
||||
|
||||
// Test that passcodes from previous and next time windows are accepted
|
||||
// (TOTP typically accepts codes from ±1 time window for clock drift)
|
||||
pastTime := time.Now().Add(-30 * time.Second)
|
||||
futureTime := time.Now().Add(30 * time.Second)
|
||||
|
||||
pastPasscode, err := totp.GenerateCode(secretStr, pastTime)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate past passcode: %v", err)
|
||||
}
|
||||
|
||||
futurePasscode, err := totp.GenerateCode(secretStr, futureTime)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate future passcode: %v", err)
|
||||
}
|
||||
|
||||
// These should be valid due to time drift tolerance
|
||||
if !Validate(pastPasscode, secretStr) {
|
||||
t.Error("Validate() rejected passcode from previous time window")
|
||||
}
|
||||
|
||||
if !Validate(futurePasscode, secretStr) {
|
||||
t.Error("Validate() rejected passcode from next time window")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_ExpiredPasscode(t *testing.T) {
|
||||
// Generate a valid secret
|
||||
secretStr, _, _, err := GenerateQRCode("testuser", "opengist.io", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate secret: %v", err)
|
||||
}
|
||||
|
||||
// Generate a passcode from 2 minutes ago (should be expired)
|
||||
oldTime := time.Now().Add(-2 * time.Minute)
|
||||
oldPasscode, err := totp.GenerateCode(secretStr, oldTime)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate old passcode: %v", err)
|
||||
}
|
||||
|
||||
// This should be invalid
|
||||
if Validate(oldPasscode, secretStr) {
|
||||
t.Error("Validate() accepted expired passcode from 2 minutes ago")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_RoundTrip(t *testing.T) {
|
||||
// Test full round trip: generate secret, generate code, validate code
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
siteUrl string
|
||||
}{
|
||||
{
|
||||
name: "basic round trip",
|
||||
username: "testuser",
|
||||
siteUrl: "opengist.io",
|
||||
},
|
||||
{
|
||||
name: "round trip with dot in username",
|
||||
username: "test.user",
|
||||
siteUrl: "opengist.io",
|
||||
},
|
||||
{
|
||||
name: "round trip with hyphen in username",
|
||||
username: "test-user",
|
||||
siteUrl: "opengist.io",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Generate QR code and secret
|
||||
secretStr, _, _, err := GenerateQRCode(tt.username, tt.siteUrl, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateQRCode() failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate a valid passcode
|
||||
passcode, err := totp.GenerateCode(secretStr, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode() failed: %v", err)
|
||||
}
|
||||
|
||||
// Validate the passcode
|
||||
if !Validate(passcode, secretStr) {
|
||||
t.Error("Validate() rejected valid passcode")
|
||||
}
|
||||
|
||||
// Validate wrong passcode fails
|
||||
wrongPasscode := "000000"
|
||||
if passcode == wrongPasscode {
|
||||
wrongPasscode = "111111"
|
||||
}
|
||||
if Validate(wrongPasscode, secretStr) {
|
||||
t.Error("Validate() accepted invalid passcode")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSecret(t *testing.T) {
|
||||
// Test the internal generateSecret function behavior through GenerateQRCode
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _, secret, err := GenerateQRCode("testuser", "opengist.io", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Iteration %d: generateSecret() failed: %v", i, err)
|
||||
}
|
||||
|
||||
if len(secret) != secretSize {
|
||||
t.Errorf("Iteration %d: secret length = %d, want %d", i, len(secret), secretSize)
|
||||
}
|
||||
|
||||
// Verify secret is not all zeros (extremely unlikely with crypto/rand)
|
||||
allZeros := true
|
||||
for _, b := range secret {
|
||||
if b != 0 {
|
||||
allZeros = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allZeros {
|
||||
t.Errorf("Iteration %d: secret is all zeros", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user