From 8d9fe818f456f786f91f5f178e5d73b4e30ea0bb Mon Sep 17 00:00:00 2001 From: Cosmo Date: Sun, 1 Mar 2026 02:32:59 +0000 Subject: [PATCH] Add unit tests for middleware, models, services, handlers, and repository helpers --- go.sum | 2 + internal/handler/handler_test.go | 114 +++++++++++++++++++ internal/middleware/auth_test.go | 165 ++++++++++++++++++++++++++++ internal/model/habit_test.go | 52 +++++++++ internal/model/savings_test.go | 61 ++++++++++ internal/model/task_test.go | 66 +++++++++++ internal/model/user_test.go | 46 ++++++++ internal/repository/helpers_test.go | 96 ++++++++++++++++ internal/service/auth_test.go | 97 ++++++++++++++++ internal/service/helpers_test.go | 35 ++++++ internal/service/interest_test.go | 66 +++++++++++ 11 files changed, 800 insertions(+) create mode 100644 internal/handler/handler_test.go create mode 100644 internal/middleware/auth_test.go create mode 100644 internal/model/habit_test.go create mode 100644 internal/model/savings_test.go create mode 100644 internal/model/task_test.go create mode 100644 internal/model/user_test.go create mode 100644 internal/repository/helpers_test.go create mode 100644 internal/service/auth_test.go create mode 100644 internal/service/helpers_test.go create mode 100644 internal/service/interest_test.go diff --git a/go.sum b/go.sum index 86f3590..a160070 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,7 @@ github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= @@ -13,5 +14,6 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go new file mode 100644 index 0000000..f20486a --- /dev/null +++ b/internal/handler/handler_test.go @@ -0,0 +1,114 @@ +package handler + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHealthHandler(t *testing.T) { + h := NewHealthHandler() + req := httptest.NewRequest("GET", "/health", nil) + rr := httptest.NewRecorder() + h.Health(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + + var resp map[string]string + json.NewDecoder(rr.Body).Decode(&resp) + if resp["status"] != "ok" { + t.Errorf("expected status ok, got %s", resp["status"]) + } + if resp["service"] != "homelab-api" { + t.Errorf("expected service homelab-api, got %s", resp["service"]) + } +} + +func TestWriteJSON(t *testing.T) { + rr := httptest.NewRecorder() + data := map[string]string{"hello": "world"} + writeJSON(rr, data, http.StatusCreated) + + if rr.Code != http.StatusCreated { + t.Errorf("expected 201, got %d", rr.Code) + } + if ct := rr.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("expected application/json, got %s", ct) + } + + var resp map[string]string + json.NewDecoder(rr.Body).Decode(&resp) + if resp["hello"] != "world" { + t.Errorf("expected world, got %s", resp["hello"]) + } +} + +func TestWriteError(t *testing.T) { + rr := httptest.NewRecorder() + writeError(rr, "something went wrong", http.StatusBadRequest) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", rr.Code) + } + + var resp map[string]string + json.NewDecoder(rr.Body).Decode(&resp) + if resp["error"] != "something went wrong" { + t.Errorf("expected 'something went wrong', got %s", resp["error"]) + } +} + +func TestInterestHandler_Unauthorized(t *testing.T) { + h := &InterestHandler{secretKey: "my-secret"} + + t.Run("missing key", func(t *testing.T) { + req := httptest.NewRequest("POST", "/internal/calculate-interest", nil) + rr := httptest.NewRecorder() + h.CalculateInterest(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rr.Code) + } + }) + + t.Run("wrong key", func(t *testing.T) { + req := httptest.NewRequest("POST", "/internal/calculate-interest", nil) + req.Header.Set("X-Internal-Key", "wrong-key") + rr := httptest.NewRecorder() + h.CalculateInterest(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rr.Code) + } + }) +} + +// Test request validation (without real service, just checking decoding) +func TestDecodeInvalidJSON(t *testing.T) { + tests := []struct { + name string + body string + handler http.HandlerFunc + }{ + {"invalid json", "{bad", func(w http.ResponseWriter, r *http.Request) { + var req struct{ Email string } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(tt.body)) + rr := httptest.NewRecorder() + tt.handler(rr, req) + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", rr.Code) + } + }) + } +} diff --git a/internal/middleware/auth_test.go b/internal/middleware/auth_test.go new file mode 100644 index 0000000..fb8e681 --- /dev/null +++ b/internal/middleware/auth_test.go @@ -0,0 +1,165 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const testSecret = "test-secret-key" + +func generateTestToken(userID int64, tokenType string, secret string, expiry time.Duration) string { + claims := jwt.MapClaims{ + "user_id": userID, + "type": tokenType, + "exp": time.Now().Add(expiry).Unix(), + "iat": time.Now().Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + s, _ := token.SignedString([]byte(secret)) + return s +} + +func TestAuthMiddleware_ValidToken(t *testing.T) { + m := NewAuthMiddleware(testSecret) + token := generateTestToken(42, "access", testSecret, 15*time.Minute) + + var capturedUserID int64 + handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedUserID = GetUserID(r.Context()) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + if capturedUserID != 42 { + t.Errorf("expected userID 42, got %d", capturedUserID) + } +} + +func TestAuthMiddleware_MissingHeader(t *testing.T) { + m := NewAuthMiddleware(testSecret) + handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rr.Code) + } +} + +func TestAuthMiddleware_InvalidFormat(t *testing.T) { + m := NewAuthMiddleware(testSecret) + handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + tests := []struct { + name string + header string + }{ + {"no bearer prefix", "Token abc123"}, + {"only bearer", "Bearer"}, + {"three parts", "Bearer token extra"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", tt.header) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rr.Code) + } + }) + } +} + +func TestAuthMiddleware_ExpiredToken(t *testing.T) { + m := NewAuthMiddleware(testSecret) + token := generateTestToken(1, "access", testSecret, -1*time.Hour) + + handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rr.Code) + } +} + +func TestAuthMiddleware_WrongSecret(t *testing.T) { + m := NewAuthMiddleware(testSecret) + token := generateTestToken(1, "access", "wrong-secret", 15*time.Minute) + + handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rr.Code) + } +} + +func TestAuthMiddleware_RefreshTokenRejected(t *testing.T) { + m := NewAuthMiddleware(testSecret) + token := generateTestToken(1, "refresh", testSecret, 15*time.Minute) + + handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for refresh token, got %d", rr.Code) + } +} + +func TestGetUserID_NoContext(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + userID := GetUserID(req.Context()) + if userID != 0 { + t.Errorf("expected 0 for missing context, got %d", userID) + } +} + +func TestAuthMiddleware_InvalidSigningMethod(t *testing.T) { + m := NewAuthMiddleware(testSecret) + // Create a token with none algorithm (should be rejected) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": float64(1), + "type": "access", + "exp": time.Now().Add(15 * time.Minute).Unix(), + }) + // Tamper with the token + s, _ := token.SignedString([]byte(testSecret)) + tampered := s + "tampered" + + handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tampered)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rr.Code) + } +} diff --git a/internal/model/habit_test.go b/internal/model/habit_test.go new file mode 100644 index 0000000..3abf8b3 --- /dev/null +++ b/internal/model/habit_test.go @@ -0,0 +1,52 @@ +package model + +import ( + "database/sql" + "testing" + "time" +) + +func TestHabit_ProcessForJSON(t *testing.T) { + t.Run("with reminder time RFC3339 format", func(t *testing.T) { + h := &Habit{ + ReminderTime: sql.NullString{String: "0000-01-01T19:00:00Z", Valid: true}, + StartDate: sql.NullTime{Time: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC), Valid: true}, + } + h.ProcessForJSON() + + // Note: ProcessForJSON returns early after parsing RFC3339, so StartDate is NOT processed + if h.ReminderTimeStr == nil || *h.ReminderTimeStr != "19:00" { + t.Errorf("expected 19:00, got %v", h.ReminderTimeStr) + } + }) + + t.Run("with reminder time HH:MM:SS format and start date", func(t *testing.T) { + h := &Habit{ + ReminderTime: sql.NullString{String: "08:30:00", Valid: true}, + StartDate: sql.NullTime{Time: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC), Valid: true}, + } + h.ProcessForJSON() + + if h.ReminderTimeStr == nil || *h.ReminderTimeStr != "08:30" { + t.Errorf("expected 08:30, got %v", h.ReminderTimeStr) + } + if h.StartDateStr == nil || *h.StartDateStr != "2025-01-15" { + t.Errorf("expected 2025-01-15, got %v", h.StartDateStr) + } + }) + + t.Run("without reminder time", func(t *testing.T) { + h := &Habit{ + ReminderTime: sql.NullString{Valid: false}, + StartDate: sql.NullTime{Valid: false}, + } + h.ProcessForJSON() + + if h.ReminderTimeStr != nil { + t.Error("reminder_time should be nil") + } + if h.StartDateStr != nil { + t.Error("start_date should be nil") + } + }) +} diff --git a/internal/model/savings_test.go b/internal/model/savings_test.go new file mode 100644 index 0000000..abe400f --- /dev/null +++ b/internal/model/savings_test.go @@ -0,0 +1,61 @@ +package model + +import ( + "database/sql" + "testing" + "time" +) + +func TestSavingsCategory_ProcessForJSON(t *testing.T) { + t.Run("with deposit dates", func(t *testing.T) { + c := &SavingsCategory{ + DepositStartDate: sql.NullTime{Time: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + DepositEndDate: sql.NullTime{Time: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + CreditStartDate: sql.NullTime{Time: time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + } + c.ProcessForJSON() + + if c.DepositStartStr == nil || *c.DepositStartStr != "2025-01-01" { + t.Errorf("expected 2025-01-01, got %v", c.DepositStartStr) + } + if c.DepositEndStr == nil || *c.DepositEndStr != "2026-01-01" { + t.Errorf("expected 2026-01-01, got %v", c.DepositEndStr) + } + if c.CreditStartStr == nil || *c.CreditStartStr != "2025-06-01" { + t.Errorf("expected 2025-06-01, got %v", c.CreditStartStr) + } + }) + + t.Run("without dates", func(t *testing.T) { + c := &SavingsCategory{} + c.ProcessForJSON() + + if c.DepositStartStr != nil { + t.Error("expected nil deposit_start_date") + } + }) +} + +func TestSavingsRecurringPlan_ProcessForJSON(t *testing.T) { + t.Run("with user_id", func(t *testing.T) { + p := &SavingsRecurringPlan{ + UserID: sql.NullInt64{Int64: 42, Valid: true}, + } + p.ProcessForJSON() + + if p.UserIDPtr == nil || *p.UserIDPtr != 42 { + t.Errorf("expected 42, got %v", p.UserIDPtr) + } + }) + + t.Run("without user_id", func(t *testing.T) { + p := &SavingsRecurringPlan{ + UserID: sql.NullInt64{Valid: false}, + } + p.ProcessForJSON() + + if p.UserIDPtr != nil { + t.Error("expected nil user_id") + } + }) +} diff --git a/internal/model/task_test.go b/internal/model/task_test.go new file mode 100644 index 0000000..b2aa01d --- /dev/null +++ b/internal/model/task_test.go @@ -0,0 +1,66 @@ +package model + +import ( + "database/sql" + "testing" + "time" +) + +func TestTask_ProcessForJSON(t *testing.T) { + t.Run("task with HH:MM:SS reminder and all fields", func(t *testing.T) { + task := &Task{ + DueDate: sql.NullTime{Time: time.Date(2025, 3, 15, 0, 0, 0, 0, time.UTC), Valid: true}, + ReminderTime: sql.NullString{String: "14:30:00", Valid: true}, + CompletedAt: sql.NullTime{Time: time.Now(), Valid: true}, + RecurrenceType: sql.NullString{String: "weekly", Valid: true}, + RecurrenceEndDate: sql.NullTime{Time: time.Date(2025, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + ParentTaskID: sql.NullInt64{Int64: 5, Valid: true}, + } + task.ProcessForJSON() + + if task.DueDateStr == nil || *task.DueDateStr != "2025-03-15" { + t.Errorf("expected due_date 2025-03-15, got %v", task.DueDateStr) + } + if task.ReminderTimeStr == nil || *task.ReminderTimeStr != "14:30" { + t.Errorf("expected reminder 14:30, got %v", task.ReminderTimeStr) + } + if !task.Completed { + t.Error("expected completed to be true") + } + if task.RecurrenceTypeStr == nil || *task.RecurrenceTypeStr != "weekly" { + t.Errorf("expected recurrence_type weekly, got %v", task.RecurrenceTypeStr) + } + if task.RecurrenceEndStr == nil || *task.RecurrenceEndStr != "2025-12-31" { + t.Errorf("expected recurrence_end 2025-12-31, got %v", task.RecurrenceEndStr) + } + if task.ParentTaskIDPtr == nil || *task.ParentTaskIDPtr != 5 { + t.Errorf("expected parent_task_id 5, got %v", task.ParentTaskIDPtr) + } + }) + + t.Run("task with RFC3339 reminder", func(t *testing.T) { + task := &Task{ + ReminderTime: sql.NullString{String: "0000-01-01T09:00:00Z", Valid: true}, + } + task.ProcessForJSON() + + if task.ReminderTimeStr == nil || *task.ReminderTimeStr != "09:00" { + t.Errorf("expected 09:00, got %v", task.ReminderTimeStr) + } + }) + + t.Run("incomplete task with null fields", func(t *testing.T) { + task := &Task{ + DueDate: sql.NullTime{Valid: false}, + CompletedAt: sql.NullTime{Valid: false}, + } + task.ProcessForJSON() + + if task.DueDateStr != nil { + t.Error("expected due_date nil") + } + if task.Completed { + t.Error("expected completed to be false") + } + }) +} diff --git a/internal/model/user_test.go b/internal/model/user_test.go new file mode 100644 index 0000000..730e3d5 --- /dev/null +++ b/internal/model/user_test.go @@ -0,0 +1,46 @@ +package model + +import ( + "database/sql" + "testing" +) + +func TestUser_ProcessForJSON(t *testing.T) { + t.Run("with telegram chat id", func(t *testing.T) { + u := &User{ + TelegramChatID: sql.NullInt64{Int64: 123456, Valid: true}, + MorningReminderTime: sql.NullString{String: "09:00:00", Valid: true}, + EveningReminderTime: sql.NullString{String: "21:30:00", Valid: true}, + } + u.ProcessForJSON() + + if u.TelegramChatIDValue == nil || *u.TelegramChatIDValue != 123456 { + t.Error("telegram_chat_id not set correctly") + } + if u.MorningTime != "09:00" { + t.Errorf("expected 09:00, got %s", u.MorningTime) + } + if u.EveningTime != "21:30" { + t.Errorf("expected 21:30, got %s", u.EveningTime) + } + }) + + t.Run("without telegram chat id", func(t *testing.T) { + u := &User{ + TelegramChatID: sql.NullInt64{Valid: false}, + MorningReminderTime: sql.NullString{Valid: false}, + EveningReminderTime: sql.NullString{Valid: false}, + } + u.ProcessForJSON() + + if u.TelegramChatIDValue != nil { + t.Error("telegram_chat_id should be nil") + } + if u.MorningTime != "09:00" { + t.Errorf("expected default 09:00, got %s", u.MorningTime) + } + if u.EveningTime != "21:00" { + t.Errorf("expected default 21:00, got %s", u.EveningTime) + } + }) +} diff --git a/internal/repository/helpers_test.go b/internal/repository/helpers_test.go new file mode 100644 index 0000000..f239b0f --- /dev/null +++ b/internal/repository/helpers_test.go @@ -0,0 +1,96 @@ +package repository + +import ( + "testing" + "time" + + "github.com/daniil/homelab-api/internal/model" +) + +func TestHabitFreezeRepository_CountFrozenDaysLogic(t *testing.T) { + // Test the overlap calculation logic that CountFrozenDaysInRange uses + tests := []struct { + name string + freezeStart, freezeEnd time.Time + queryStart, queryEnd time.Time + wantDays int + }{ + { + name: "full overlap", + freezeStart: time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC), + freezeEnd: time.Date(2025, 1, 10, 0, 0, 0, 0, time.UTC), + queryStart: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + queryEnd: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC), + wantDays: 6, + }, + { + name: "partial overlap start", + freezeStart: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + freezeEnd: time.Date(2025, 1, 10, 0, 0, 0, 0, time.UTC), + queryStart: time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC), + queryEnd: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC), + wantDays: 6, + }, + { + name: "partial overlap end", + freezeStart: time.Date(2025, 1, 20, 0, 0, 0, 0, time.UTC), + freezeEnd: time.Date(2025, 2, 5, 0, 0, 0, 0, time.UTC), + queryStart: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + queryEnd: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC), + wantDays: 12, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + freeze := model.HabitFreeze{ + StartDate: tt.freezeStart, + EndDate: tt.freezeEnd, + } + + overlapStart := freeze.StartDate + if tt.queryStart.After(freeze.StartDate) { + overlapStart = tt.queryStart + } + overlapEnd := freeze.EndDate + if tt.queryEnd.Before(freeze.EndDate) { + overlapEnd = tt.queryEnd + } + + days := 0 + if !overlapEnd.Before(overlapStart) { + days = int(overlapEnd.Sub(overlapStart).Hours()/24) + 1 + } + + if days != tt.wantDays { + t.Errorf("got %d frozen days, want %d", days, tt.wantDays) + } + }) + } +} + +func TestJoinStrings(t *testing.T) { + tests := []struct { + input []string + sep string + want string + }{ + {nil, ", ", ""}, + {[]string{"a"}, ", ", "a"}, + {[]string{"a", "b", "c"}, ", ", "a, b, c"}, + {[]string{"x", "y"}, " AND ", "x AND y"}, + } + + for _, tt := range tests { + got := joinStrings(tt.input, tt.sep) + if got != tt.want { + t.Errorf("joinStrings(%v, %q) = %q, want %q", tt.input, tt.sep, got, tt.want) + } + } +} + +func TestIsUniqueViolation(t *testing.T) { + if isUniqueViolation(nil) { + t.Error("nil error should not be unique violation") + } +} diff --git a/internal/service/auth_test.go b/internal/service/auth_test.go new file mode 100644 index 0000000..ed4770f --- /dev/null +++ b/internal/service/auth_test.go @@ -0,0 +1,97 @@ +package service + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func TestAuthService_GenerateAndValidateToken(t *testing.T) { + s := &AuthService{jwtSecret: "test-secret"} + + t.Run("valid access token", func(t *testing.T) { + tokenStr, err := s.generateToken(1, "access", 15*time.Minute) + if err != nil { + t.Fatalf("generateToken error: %v", err) + } + + claims, err := s.validateToken(tokenStr, "access") + if err != nil { + t.Fatalf("validateToken error: %v", err) + } + + userID, ok := claims["user_id"].(float64) + if !ok || int64(userID) != 1 { + t.Errorf("expected user_id 1, got %v", claims["user_id"]) + } + }) + + t.Run("wrong token type rejected", func(t *testing.T) { + tokenStr, _ := s.generateToken(1, "refresh", time.Hour) + _, err := s.validateToken(tokenStr, "access") + if err != ErrInvalidToken { + t.Errorf("expected ErrInvalidToken, got %v", err) + } + }) + + t.Run("expired token rejected", func(t *testing.T) { + tokenStr, _ := s.generateToken(1, "access", -time.Hour) + _, err := s.validateToken(tokenStr, "access") + if err != ErrInvalidToken { + t.Errorf("expected ErrInvalidToken, got %v", err) + } + }) + + t.Run("wrong secret rejected", func(t *testing.T) { + otherService := &AuthService{jwtSecret: "other-secret"} + tokenStr, _ := otherService.generateToken(1, "access", time.Hour) + _, err := s.validateToken(tokenStr, "access") + if err != ErrInvalidToken { + t.Errorf("expected ErrInvalidToken, got %v", err) + } + }) + + t.Run("tampered token rejected", func(t *testing.T) { + tokenStr, _ := s.generateToken(1, "access", time.Hour) + _, err := s.validateToken(tokenStr+"x", "access") + if err != ErrInvalidToken { + t.Errorf("expected ErrInvalidToken, got %v", err) + } + }) + + t.Run("HMAC signing method accepted", func(t *testing.T) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": float64(1), + "type": "access", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenStr, _ := token.SignedString([]byte("test-secret")) + + claims, err := s.validateToken(tokenStr, "access") + if err != nil { + t.Fatalf("should accept HS256: %v", err) + } + if claims["type"] != "access" { + t.Error("claims type mismatch") + } + }) +} + +func TestErrWeakPassword(t *testing.T) { + if ErrWeakPassword.Error() != "password must be at least 8 characters" { + t.Errorf("unexpected error message: %s", ErrWeakPassword.Error()) + } +} + +func TestErrInvalidCredentials(t *testing.T) { + if ErrInvalidCredentials.Error() != "invalid credentials" { + t.Errorf("unexpected error message: %s", ErrInvalidCredentials.Error()) + } +} + +func TestErrEmailNotVerified(t *testing.T) { + if ErrEmailNotVerified.Error() != "email not verified" { + t.Errorf("unexpected error message: %s", ErrEmailNotVerified.Error()) + } +} diff --git a/internal/service/helpers_test.go b/internal/service/helpers_test.go new file mode 100644 index 0000000..9c8d2c3 --- /dev/null +++ b/internal/service/helpers_test.go @@ -0,0 +1,35 @@ +package service + +import "testing" + +func TestDefaultString(t *testing.T) { + tests := []struct { + val, def, want string + }{ + {"hello", "default", "hello"}, + {"", "default", "default"}, + {"", "", ""}, + } + for _, tt := range tests { + got := defaultString(tt.val, tt.def) + if got != tt.want { + t.Errorf("defaultString(%q, %q) = %q, want %q", tt.val, tt.def, got, tt.want) + } + } +} + +func TestDefaultInt(t *testing.T) { + tests := []struct { + val, def, want int + }{ + {5, 10, 5}, + {0, 10, 10}, + {0, 0, 0}, + } + for _, tt := range tests { + got := defaultInt(tt.val, tt.def) + if got != tt.want { + t.Errorf("defaultInt(%d, %d) = %d, want %d", tt.val, tt.def, got, tt.want) + } + } +} diff --git a/internal/service/interest_test.go b/internal/service/interest_test.go new file mode 100644 index 0000000..2708bef --- /dev/null +++ b/internal/service/interest_test.go @@ -0,0 +1,66 @@ +package service + +import ( + "database/sql" + "testing" + "time" + + "github.com/daniil/homelab-api/internal/model" +) + +func TestCalculateInterestForDeposit_NotDeposit(t *testing.T) { + s := &InterestService{} + deposit := &model.SavingsCategory{IsDeposit: false} + result, err := s.CalculateInterestForDeposit(deposit) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "" { + t.Errorf("expected empty result for non-deposit, got %q", result) + } +} + +func TestCalculateInterestForDeposit_ZeroRate(t *testing.T) { + s := &InterestService{} + deposit := &model.SavingsCategory{IsDeposit: true, InterestRate: 0} + result, err := s.CalculateInterestForDeposit(deposit) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "" { + t.Errorf("expected empty result for zero rate, got %q", result) + } +} + +func TestCalculateInterestForDeposit_NoStartDate(t *testing.T) { + s := &InterestService{} + deposit := &model.SavingsCategory{ + IsDeposit: true, + InterestRate: 10, + DepositStartDate: sql.NullTime{Valid: false}, + } + _, err := s.CalculateInterestForDeposit(deposit) + if err == nil { + t.Error("expected error for missing start date") + } +} + +func TestCalculateInterestForDeposit_ExpiredDeposit(t *testing.T) { + s := &InterestService{} + deposit := &model.SavingsCategory{ + IsDeposit: true, + InterestRate: 10, + DepositTerm: 3, // 3 months + DepositStartDate: sql.NullTime{ + Time: time.Now().AddDate(0, -6, 0), // 6 months ago + Valid: true, + }, + } + result, err := s.CalculateInterestForDeposit(deposit) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "" { + t.Errorf("expected empty result for expired deposit, got %q", result) + } +}