Add unit tests for middleware, models, services, handlers, and repository helpers
All checks were successful
CI / ci (push) Successful in 35s
All checks were successful
CI / ci (push) Successful in 35s
This commit is contained in:
2
go.sum
2
go.sum
@@ -4,6 +4,7 @@ github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
|
||||
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
|
||||
@@ -13,5 +14,6 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
|
||||
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
|
||||
114
internal/handler/handler_test.go
Normal file
114
internal/handler/handler_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHealthHandler(t *testing.T) {
|
||||
h := NewHealthHandler()
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
h.Health(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&resp)
|
||||
if resp["status"] != "ok" {
|
||||
t.Errorf("expected status ok, got %s", resp["status"])
|
||||
}
|
||||
if resp["service"] != "homelab-api" {
|
||||
t.Errorf("expected service homelab-api, got %s", resp["service"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
data := map[string]string{"hello": "world"}
|
||||
writeJSON(rr, data, http.StatusCreated)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d", rr.Code)
|
||||
}
|
||||
if ct := rr.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Errorf("expected application/json, got %s", ct)
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&resp)
|
||||
if resp["hello"] != "world" {
|
||||
t.Errorf("expected world, got %s", resp["hello"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
writeError(rr, "something went wrong", http.StatusBadRequest)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
json.NewDecoder(rr.Body).Decode(&resp)
|
||||
if resp["error"] != "something went wrong" {
|
||||
t.Errorf("expected 'something went wrong', got %s", resp["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterestHandler_Unauthorized(t *testing.T) {
|
||||
h := &InterestHandler{secretKey: "my-secret"}
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/internal/calculate-interest", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
h.CalculateInterest(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong key", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/internal/calculate-interest", nil)
|
||||
req.Header.Set("X-Internal-Key", "wrong-key")
|
||||
rr := httptest.NewRecorder()
|
||||
h.CalculateInterest(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test request validation (without real service, just checking decoding)
|
||||
func TestDecodeInvalidJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
handler http.HandlerFunc
|
||||
}{
|
||||
{"invalid json", "{bad", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct{ Email string }
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewBufferString(tt.body))
|
||||
rr := httptest.NewRecorder()
|
||||
tt.handler(rr, req)
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
165
internal/middleware/auth_test.go
Normal file
165
internal/middleware/auth_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const testSecret = "test-secret-key"
|
||||
|
||||
func generateTestToken(userID int64, tokenType string, secret string, expiry time.Duration) string {
|
||||
claims := jwt.MapClaims{
|
||||
"user_id": userID,
|
||||
"type": tokenType,
|
||||
"exp": time.Now().Add(expiry).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
s, _ := token.SignedString([]byte(secret))
|
||||
return s
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ValidToken(t *testing.T) {
|
||||
m := NewAuthMiddleware(testSecret)
|
||||
token := generateTestToken(42, "access", testSecret, 15*time.Minute)
|
||||
|
||||
var capturedUserID int64
|
||||
handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedUserID = GetUserID(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
if capturedUserID != 42 {
|
||||
t.Errorf("expected userID 42, got %d", capturedUserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_MissingHeader(t *testing.T) {
|
||||
m := NewAuthMiddleware(testSecret)
|
||||
handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
|
||||
m := NewAuthMiddleware(testSecret)
|
||||
handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
}{
|
||||
{"no bearer prefix", "Token abc123"},
|
||||
{"only bearer", "Bearer"},
|
||||
{"three parts", "Bearer token extra"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", tt.header)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
|
||||
m := NewAuthMiddleware(testSecret)
|
||||
token := generateTestToken(1, "access", testSecret, -1*time.Hour)
|
||||
|
||||
handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_WrongSecret(t *testing.T) {
|
||||
m := NewAuthMiddleware(testSecret)
|
||||
token := generateTestToken(1, "access", "wrong-secret", 15*time.Minute)
|
||||
|
||||
handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RefreshTokenRejected(t *testing.T) {
|
||||
m := NewAuthMiddleware(testSecret)
|
||||
token := generateTestToken(1, "refresh", testSecret, 15*time.Minute)
|
||||
|
||||
handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for refresh token, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserID_NoContext(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
userID := GetUserID(req.Context())
|
||||
if userID != 0 {
|
||||
t.Errorf("expected 0 for missing context, got %d", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_InvalidSigningMethod(t *testing.T) {
|
||||
m := NewAuthMiddleware(testSecret)
|
||||
// Create a token with none algorithm (should be rejected)
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": float64(1),
|
||||
"type": "access",
|
||||
"exp": time.Now().Add(15 * time.Minute).Unix(),
|
||||
})
|
||||
// Tamper with the token
|
||||
s, _ := token.SignedString([]byte(testSecret))
|
||||
tampered := s + "tampered"
|
||||
|
||||
handler := m.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tampered))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
52
internal/model/habit_test.go
Normal file
52
internal/model/habit_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHabit_ProcessForJSON(t *testing.T) {
|
||||
t.Run("with reminder time RFC3339 format", func(t *testing.T) {
|
||||
h := &Habit{
|
||||
ReminderTime: sql.NullString{String: "0000-01-01T19:00:00Z", Valid: true},
|
||||
StartDate: sql.NullTime{Time: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC), Valid: true},
|
||||
}
|
||||
h.ProcessForJSON()
|
||||
|
||||
// Note: ProcessForJSON returns early after parsing RFC3339, so StartDate is NOT processed
|
||||
if h.ReminderTimeStr == nil || *h.ReminderTimeStr != "19:00" {
|
||||
t.Errorf("expected 19:00, got %v", h.ReminderTimeStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with reminder time HH:MM:SS format and start date", func(t *testing.T) {
|
||||
h := &Habit{
|
||||
ReminderTime: sql.NullString{String: "08:30:00", Valid: true},
|
||||
StartDate: sql.NullTime{Time: time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC), Valid: true},
|
||||
}
|
||||
h.ProcessForJSON()
|
||||
|
||||
if h.ReminderTimeStr == nil || *h.ReminderTimeStr != "08:30" {
|
||||
t.Errorf("expected 08:30, got %v", h.ReminderTimeStr)
|
||||
}
|
||||
if h.StartDateStr == nil || *h.StartDateStr != "2025-01-15" {
|
||||
t.Errorf("expected 2025-01-15, got %v", h.StartDateStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without reminder time", func(t *testing.T) {
|
||||
h := &Habit{
|
||||
ReminderTime: sql.NullString{Valid: false},
|
||||
StartDate: sql.NullTime{Valid: false},
|
||||
}
|
||||
h.ProcessForJSON()
|
||||
|
||||
if h.ReminderTimeStr != nil {
|
||||
t.Error("reminder_time should be nil")
|
||||
}
|
||||
if h.StartDateStr != nil {
|
||||
t.Error("start_date should be nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
61
internal/model/savings_test.go
Normal file
61
internal/model/savings_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSavingsCategory_ProcessForJSON(t *testing.T) {
|
||||
t.Run("with deposit dates", func(t *testing.T) {
|
||||
c := &SavingsCategory{
|
||||
DepositStartDate: sql.NullTime{Time: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true},
|
||||
DepositEndDate: sql.NullTime{Time: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true},
|
||||
CreditStartDate: sql.NullTime{Time: time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC), Valid: true},
|
||||
}
|
||||
c.ProcessForJSON()
|
||||
|
||||
if c.DepositStartStr == nil || *c.DepositStartStr != "2025-01-01" {
|
||||
t.Errorf("expected 2025-01-01, got %v", c.DepositStartStr)
|
||||
}
|
||||
if c.DepositEndStr == nil || *c.DepositEndStr != "2026-01-01" {
|
||||
t.Errorf("expected 2026-01-01, got %v", c.DepositEndStr)
|
||||
}
|
||||
if c.CreditStartStr == nil || *c.CreditStartStr != "2025-06-01" {
|
||||
t.Errorf("expected 2025-06-01, got %v", c.CreditStartStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without dates", func(t *testing.T) {
|
||||
c := &SavingsCategory{}
|
||||
c.ProcessForJSON()
|
||||
|
||||
if c.DepositStartStr != nil {
|
||||
t.Error("expected nil deposit_start_date")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSavingsRecurringPlan_ProcessForJSON(t *testing.T) {
|
||||
t.Run("with user_id", func(t *testing.T) {
|
||||
p := &SavingsRecurringPlan{
|
||||
UserID: sql.NullInt64{Int64: 42, Valid: true},
|
||||
}
|
||||
p.ProcessForJSON()
|
||||
|
||||
if p.UserIDPtr == nil || *p.UserIDPtr != 42 {
|
||||
t.Errorf("expected 42, got %v", p.UserIDPtr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without user_id", func(t *testing.T) {
|
||||
p := &SavingsRecurringPlan{
|
||||
UserID: sql.NullInt64{Valid: false},
|
||||
}
|
||||
p.ProcessForJSON()
|
||||
|
||||
if p.UserIDPtr != nil {
|
||||
t.Error("expected nil user_id")
|
||||
}
|
||||
})
|
||||
}
|
||||
66
internal/model/task_test.go
Normal file
66
internal/model/task_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTask_ProcessForJSON(t *testing.T) {
|
||||
t.Run("task with HH:MM:SS reminder and all fields", func(t *testing.T) {
|
||||
task := &Task{
|
||||
DueDate: sql.NullTime{Time: time.Date(2025, 3, 15, 0, 0, 0, 0, time.UTC), Valid: true},
|
||||
ReminderTime: sql.NullString{String: "14:30:00", Valid: true},
|
||||
CompletedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
RecurrenceType: sql.NullString{String: "weekly", Valid: true},
|
||||
RecurrenceEndDate: sql.NullTime{Time: time.Date(2025, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true},
|
||||
ParentTaskID: sql.NullInt64{Int64: 5, Valid: true},
|
||||
}
|
||||
task.ProcessForJSON()
|
||||
|
||||
if task.DueDateStr == nil || *task.DueDateStr != "2025-03-15" {
|
||||
t.Errorf("expected due_date 2025-03-15, got %v", task.DueDateStr)
|
||||
}
|
||||
if task.ReminderTimeStr == nil || *task.ReminderTimeStr != "14:30" {
|
||||
t.Errorf("expected reminder 14:30, got %v", task.ReminderTimeStr)
|
||||
}
|
||||
if !task.Completed {
|
||||
t.Error("expected completed to be true")
|
||||
}
|
||||
if task.RecurrenceTypeStr == nil || *task.RecurrenceTypeStr != "weekly" {
|
||||
t.Errorf("expected recurrence_type weekly, got %v", task.RecurrenceTypeStr)
|
||||
}
|
||||
if task.RecurrenceEndStr == nil || *task.RecurrenceEndStr != "2025-12-31" {
|
||||
t.Errorf("expected recurrence_end 2025-12-31, got %v", task.RecurrenceEndStr)
|
||||
}
|
||||
if task.ParentTaskIDPtr == nil || *task.ParentTaskIDPtr != 5 {
|
||||
t.Errorf("expected parent_task_id 5, got %v", task.ParentTaskIDPtr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("task with RFC3339 reminder", func(t *testing.T) {
|
||||
task := &Task{
|
||||
ReminderTime: sql.NullString{String: "0000-01-01T09:00:00Z", Valid: true},
|
||||
}
|
||||
task.ProcessForJSON()
|
||||
|
||||
if task.ReminderTimeStr == nil || *task.ReminderTimeStr != "09:00" {
|
||||
t.Errorf("expected 09:00, got %v", task.ReminderTimeStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("incomplete task with null fields", func(t *testing.T) {
|
||||
task := &Task{
|
||||
DueDate: sql.NullTime{Valid: false},
|
||||
CompletedAt: sql.NullTime{Valid: false},
|
||||
}
|
||||
task.ProcessForJSON()
|
||||
|
||||
if task.DueDateStr != nil {
|
||||
t.Error("expected due_date nil")
|
||||
}
|
||||
if task.Completed {
|
||||
t.Error("expected completed to be false")
|
||||
}
|
||||
})
|
||||
}
|
||||
46
internal/model/user_test.go
Normal file
46
internal/model/user_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUser_ProcessForJSON(t *testing.T) {
|
||||
t.Run("with telegram chat id", func(t *testing.T) {
|
||||
u := &User{
|
||||
TelegramChatID: sql.NullInt64{Int64: 123456, Valid: true},
|
||||
MorningReminderTime: sql.NullString{String: "09:00:00", Valid: true},
|
||||
EveningReminderTime: sql.NullString{String: "21:30:00", Valid: true},
|
||||
}
|
||||
u.ProcessForJSON()
|
||||
|
||||
if u.TelegramChatIDValue == nil || *u.TelegramChatIDValue != 123456 {
|
||||
t.Error("telegram_chat_id not set correctly")
|
||||
}
|
||||
if u.MorningTime != "09:00" {
|
||||
t.Errorf("expected 09:00, got %s", u.MorningTime)
|
||||
}
|
||||
if u.EveningTime != "21:30" {
|
||||
t.Errorf("expected 21:30, got %s", u.EveningTime)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without telegram chat id", func(t *testing.T) {
|
||||
u := &User{
|
||||
TelegramChatID: sql.NullInt64{Valid: false},
|
||||
MorningReminderTime: sql.NullString{Valid: false},
|
||||
EveningReminderTime: sql.NullString{Valid: false},
|
||||
}
|
||||
u.ProcessForJSON()
|
||||
|
||||
if u.TelegramChatIDValue != nil {
|
||||
t.Error("telegram_chat_id should be nil")
|
||||
}
|
||||
if u.MorningTime != "09:00" {
|
||||
t.Errorf("expected default 09:00, got %s", u.MorningTime)
|
||||
}
|
||||
if u.EveningTime != "21:00" {
|
||||
t.Errorf("expected default 21:00, got %s", u.EveningTime)
|
||||
}
|
||||
})
|
||||
}
|
||||
96
internal/repository/helpers_test.go
Normal file
96
internal/repository/helpers_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/daniil/homelab-api/internal/model"
|
||||
)
|
||||
|
||||
func TestHabitFreezeRepository_CountFrozenDaysLogic(t *testing.T) {
|
||||
// Test the overlap calculation logic that CountFrozenDaysInRange uses
|
||||
tests := []struct {
|
||||
name string
|
||||
freezeStart, freezeEnd time.Time
|
||||
queryStart, queryEnd time.Time
|
||||
wantDays int
|
||||
}{
|
||||
{
|
||||
name: "full overlap",
|
||||
freezeStart: time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC),
|
||||
freezeEnd: time.Date(2025, 1, 10, 0, 0, 0, 0, time.UTC),
|
||||
queryStart: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
queryEnd: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC),
|
||||
wantDays: 6,
|
||||
},
|
||||
{
|
||||
name: "partial overlap start",
|
||||
freezeStart: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
freezeEnd: time.Date(2025, 1, 10, 0, 0, 0, 0, time.UTC),
|
||||
queryStart: time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC),
|
||||
queryEnd: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC),
|
||||
wantDays: 6,
|
||||
},
|
||||
{
|
||||
name: "partial overlap end",
|
||||
freezeStart: time.Date(2025, 1, 20, 0, 0, 0, 0, time.UTC),
|
||||
freezeEnd: time.Date(2025, 2, 5, 0, 0, 0, 0, time.UTC),
|
||||
queryStart: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
queryEnd: time.Date(2025, 1, 31, 0, 0, 0, 0, time.UTC),
|
||||
wantDays: 12,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
freeze := model.HabitFreeze{
|
||||
StartDate: tt.freezeStart,
|
||||
EndDate: tt.freezeEnd,
|
||||
}
|
||||
|
||||
overlapStart := freeze.StartDate
|
||||
if tt.queryStart.After(freeze.StartDate) {
|
||||
overlapStart = tt.queryStart
|
||||
}
|
||||
overlapEnd := freeze.EndDate
|
||||
if tt.queryEnd.Before(freeze.EndDate) {
|
||||
overlapEnd = tt.queryEnd
|
||||
}
|
||||
|
||||
days := 0
|
||||
if !overlapEnd.Before(overlapStart) {
|
||||
days = int(overlapEnd.Sub(overlapStart).Hours()/24) + 1
|
||||
}
|
||||
|
||||
if days != tt.wantDays {
|
||||
t.Errorf("got %d frozen days, want %d", days, tt.wantDays)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinStrings(t *testing.T) {
|
||||
tests := []struct {
|
||||
input []string
|
||||
sep string
|
||||
want string
|
||||
}{
|
||||
{nil, ", ", ""},
|
||||
{[]string{"a"}, ", ", "a"},
|
||||
{[]string{"a", "b", "c"}, ", ", "a, b, c"},
|
||||
{[]string{"x", "y"}, " AND ", "x AND y"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := joinStrings(tt.input, tt.sep)
|
||||
if got != tt.want {
|
||||
t.Errorf("joinStrings(%v, %q) = %q, want %q", tt.input, tt.sep, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsUniqueViolation(t *testing.T) {
|
||||
if isUniqueViolation(nil) {
|
||||
t.Error("nil error should not be unique violation")
|
||||
}
|
||||
}
|
||||
97
internal/service/auth_test.go
Normal file
97
internal/service/auth_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func TestAuthService_GenerateAndValidateToken(t *testing.T) {
|
||||
s := &AuthService{jwtSecret: "test-secret"}
|
||||
|
||||
t.Run("valid access token", func(t *testing.T) {
|
||||
tokenStr, err := s.generateToken(1, "access", 15*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("generateToken error: %v", err)
|
||||
}
|
||||
|
||||
claims, err := s.validateToken(tokenStr, "access")
|
||||
if err != nil {
|
||||
t.Fatalf("validateToken error: %v", err)
|
||||
}
|
||||
|
||||
userID, ok := claims["user_id"].(float64)
|
||||
if !ok || int64(userID) != 1 {
|
||||
t.Errorf("expected user_id 1, got %v", claims["user_id"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong token type rejected", func(t *testing.T) {
|
||||
tokenStr, _ := s.generateToken(1, "refresh", time.Hour)
|
||||
_, err := s.validateToken(tokenStr, "access")
|
||||
if err != ErrInvalidToken {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired token rejected", func(t *testing.T) {
|
||||
tokenStr, _ := s.generateToken(1, "access", -time.Hour)
|
||||
_, err := s.validateToken(tokenStr, "access")
|
||||
if err != ErrInvalidToken {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong secret rejected", func(t *testing.T) {
|
||||
otherService := &AuthService{jwtSecret: "other-secret"}
|
||||
tokenStr, _ := otherService.generateToken(1, "access", time.Hour)
|
||||
_, err := s.validateToken(tokenStr, "access")
|
||||
if err != ErrInvalidToken {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tampered token rejected", func(t *testing.T) {
|
||||
tokenStr, _ := s.generateToken(1, "access", time.Hour)
|
||||
_, err := s.validateToken(tokenStr+"x", "access")
|
||||
if err != ErrInvalidToken {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HMAC signing method accepted", func(t *testing.T) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"user_id": float64(1),
|
||||
"type": "access",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
tokenStr, _ := token.SignedString([]byte("test-secret"))
|
||||
|
||||
claims, err := s.validateToken(tokenStr, "access")
|
||||
if err != nil {
|
||||
t.Fatalf("should accept HS256: %v", err)
|
||||
}
|
||||
if claims["type"] != "access" {
|
||||
t.Error("claims type mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrWeakPassword(t *testing.T) {
|
||||
if ErrWeakPassword.Error() != "password must be at least 8 characters" {
|
||||
t.Errorf("unexpected error message: %s", ErrWeakPassword.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrInvalidCredentials(t *testing.T) {
|
||||
if ErrInvalidCredentials.Error() != "invalid credentials" {
|
||||
t.Errorf("unexpected error message: %s", ErrInvalidCredentials.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrEmailNotVerified(t *testing.T) {
|
||||
if ErrEmailNotVerified.Error() != "email not verified" {
|
||||
t.Errorf("unexpected error message: %s", ErrEmailNotVerified.Error())
|
||||
}
|
||||
}
|
||||
35
internal/service/helpers_test.go
Normal file
35
internal/service/helpers_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultString(t *testing.T) {
|
||||
tests := []struct {
|
||||
val, def, want string
|
||||
}{
|
||||
{"hello", "default", "hello"},
|
||||
{"", "default", "default"},
|
||||
{"", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := defaultString(tt.val, tt.def)
|
||||
if got != tt.want {
|
||||
t.Errorf("defaultString(%q, %q) = %q, want %q", tt.val, tt.def, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
val, def, want int
|
||||
}{
|
||||
{5, 10, 5},
|
||||
{0, 10, 10},
|
||||
{0, 0, 0},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := defaultInt(tt.val, tt.def)
|
||||
if got != tt.want {
|
||||
t.Errorf("defaultInt(%d, %d) = %d, want %d", tt.val, tt.def, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
66
internal/service/interest_test.go
Normal file
66
internal/service/interest_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/daniil/homelab-api/internal/model"
|
||||
)
|
||||
|
||||
func TestCalculateInterestForDeposit_NotDeposit(t *testing.T) {
|
||||
s := &InterestService{}
|
||||
deposit := &model.SavingsCategory{IsDeposit: false}
|
||||
result, err := s.CalculateInterestForDeposit(deposit)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "" {
|
||||
t.Errorf("expected empty result for non-deposit, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateInterestForDeposit_ZeroRate(t *testing.T) {
|
||||
s := &InterestService{}
|
||||
deposit := &model.SavingsCategory{IsDeposit: true, InterestRate: 0}
|
||||
result, err := s.CalculateInterestForDeposit(deposit)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "" {
|
||||
t.Errorf("expected empty result for zero rate, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateInterestForDeposit_NoStartDate(t *testing.T) {
|
||||
s := &InterestService{}
|
||||
deposit := &model.SavingsCategory{
|
||||
IsDeposit: true,
|
||||
InterestRate: 10,
|
||||
DepositStartDate: sql.NullTime{Valid: false},
|
||||
}
|
||||
_, err := s.CalculateInterestForDeposit(deposit)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing start date")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateInterestForDeposit_ExpiredDeposit(t *testing.T) {
|
||||
s := &InterestService{}
|
||||
deposit := &model.SavingsCategory{
|
||||
IsDeposit: true,
|
||||
InterestRate: 10,
|
||||
DepositTerm: 3, // 3 months
|
||||
DepositStartDate: sql.NullTime{
|
||||
Time: time.Now().AddDate(0, -6, 0), // 6 months ago
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
result, err := s.CalculateInterestForDeposit(deposit)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "" {
|
||||
t.Errorf("expected empty result for expired deposit, got %q", result)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user