package middleware

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

// testSecret is the HMAC key shared by the middleware under test and by the
// tokens minted in these tests.
const testSecret = "test-secret-key"

// generateTestToken returns an HS256-signed JWT with the claims user_id, type,
// exp and iat, signed with secret. exp is now+expiry, so a negative expiry
// yields an already-expired token.
//
// It panics if signing fails. Silently returning "" would let the negative
// tests pass for the wrong reason, since an empty token is always rejected.
func generateTestToken(userID int64, tokenType string, secret string, expiry time.Duration) string {
	now := time.Now() // single instant so exp and iat agree
	claims := jwt.MapClaims{
		"user_id": userID,
		"type":    tokenType,
		"exp":     now.Add(expiry).Unix(),
		"iat":     now.Unix(),
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	s, err := token.SignedString([]byte(secret))
	if err != nil {
		panic(fmt.Sprintf("generateTestToken: signing token: %v", err))
	}
	return s
}

// serveAuthRequest sends GET /test through h and returns the recorded
// response. The Authorization header is set only when authorization is
// non-empty, so "" models a request that carries no header at all.
func serveAuthRequest(h http.Handler, authorization string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	if authorization != "" {
		req.Header.Set("Authorization", authorization)
	}
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)
	return rr
}

// nextMustNotRun returns a downstream handler that fails t if it is invoked.
// Negative tests wrap it so that a middleware which wrongly lets a request
// through is reported explicitly, not only through the status code.
func nextMustNotRun(t *testing.T) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("downstream handler was called for a request that should have been rejected")
	})
}

// expectUnauthorized reports a test error unless rr carries HTTP 401.
func expectUnauthorized(t *testing.T, rr *httptest.ResponseRecorder) {
	t.Helper()
	if rr.Code != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", rr.Code)
	}
}

func TestAuthMiddleware_ValidToken(t *testing.T) {
	m := NewAuthMiddleware(testSecret)
	token := generateTestToken(42, "access", testSecret, 15*time.Minute)

	var capturedUserID int64
	handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		capturedUserID = GetUserID(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	rr := serveAuthRequest(handler, "Bearer "+token)

	if rr.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rr.Code)
	}
	if capturedUserID != 42 {
		t.Errorf("expected userID 42, got %d", capturedUserID)
	}
}

func TestAuthMiddleware_MissingHeader(t *testing.T) {
	m := NewAuthMiddleware(testSecret)

	rr := serveAuthRequest(m.Authenticate(nextMustNotRun(t)), "")

	expectUnauthorized(t, rr)
}

func TestAuthMiddleware_InvalidFormat(t *testing.T) {
	m := NewAuthMiddleware(testSecret)

	tests := []struct {
		name   string
		header string
	}{
		{"no bearer prefix", "Token abc123"},
		{"only bearer", "Bearer"},
		{"three parts", "Bearer token extra"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Build the chain per subtest so a pass-through failure is
			// attributed to this subtest's t.
			rr := serveAuthRequest(m.Authenticate(nextMustNotRun(t)), tt.header)
			expectUnauthorized(t, rr)
		})
	}
}

func TestAuthMiddleware_ExpiredToken(t *testing.T) {
	m := NewAuthMiddleware(testSecret)
	token := generateTestToken(1, "access", testSecret, -1*time.Hour)

	rr := serveAuthRequest(m.Authenticate(nextMustNotRun(t)), "Bearer "+token)

	expectUnauthorized(t, rr)
}

func TestAuthMiddleware_WrongSecret(t *testing.T) {
	m := NewAuthMiddleware(testSecret)
	token := generateTestToken(1, "access", "wrong-secret", 15*time.Minute)

	rr := serveAuthRequest(m.Authenticate(nextMustNotRun(t)), "Bearer "+token)

	expectUnauthorized(t, rr)
}

func TestAuthMiddleware_RefreshTokenRejected(t *testing.T) {
	m := NewAuthMiddleware(testSecret)
	token := generateTestToken(1, "refresh", testSecret, 15*time.Minute)

	rr := serveAuthRequest(m.Authenticate(nextMustNotRun(t)), "Bearer "+token)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("expected 401 for refresh token, got %d", rr.Code)
	}
}

// TestGetUserID_NoContext checks that GetUserID returns 0 for a request
// context into which no user ID was stored.
func TestGetUserID_NoContext(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/test", nil)

	if userID := GetUserID(req.Context()); userID != 0 {
		t.Errorf("expected 0 for missing context, got %d", userID)
	}
}

// TestAuthMiddleware_InvalidSigningMethod checks that tokens which cannot be
// verified with the HS256 secret are rejected. It covers an unsigned
// "alg: none" token and an HS256 token whose signature has been tampered with.
func TestAuthMiddleware_InvalidSigningMethod(t *testing.T) {
	m := NewAuthMiddleware(testSecret)
	claims := jwt.MapClaims{
		"user_id": float64(1),
		"type":    "access",
		"exp":     time.Now().Add(15 * time.Minute).Unix(),
	}

	t.Run("none algorithm", func(t *testing.T) {
		// jwt only signs with "none" when given this sentinel key; the
		// middleware must still refuse the resulting unsigned token.
		token, err := jwt.NewWithClaims(jwt.SigningMethodNone, claims).
			SignedString(jwt.UnsafeAllowNoneSignatureType)
		if err != nil {
			t.Fatalf("signing none-algorithm token: %v", err)
		}

		rr := serveAuthRequest(m.Authenticate(nextMustNotRun(t)), "Bearer "+token)

		expectUnauthorized(t, rr)
	})

	t.Run("tampered signature", func(t *testing.T) {
		signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
			SignedString([]byte(testSecret))
		if err != nil {
			t.Fatalf("signing HS256 token: %v", err)
		}
		// Appending to the signature segment invalidates it.
		tampered := signed + "tampered"

		rr := serveAuthRequest(m.Authenticate(nextMustNotRun(t)), fmt.Sprintf("Bearer %s", tampered))

		expectUnauthorized(t, rr)
	})
}