From 2a50e50771512ba08495a67520f5ac872c02cfb5 Mon Sep 17 00:00:00 2001 From: Cosmo Date: Mon, 16 Feb 2026 06:48:09 +0000 Subject: [PATCH] feat(savings): Add savings module with categories, transactions, recurring plans - Categories: regular, deposits, credits, recurring, multi-user, accounts - Transactions: deposits and withdrawals with user tracking - Recurring plans: monthly payment obligations per user - Stats: overdues calculation with allocation algorithm - Excludes is_account categories from total sums - Documentation: docs/SAVINGS.md --- cmd/api/main.go | 41 +- docs/SAVINGS.md | 56 ++ internal/bot/handlers.go | 101 ++- internal/handler/habit_freeze.go | 135 ++++ internal/handler/habits.go | 8 + internal/handler/savings.go | 489 +++++++++++++ internal/model/habit.go | 10 + internal/model/habit_freeze.go | 21 + internal/model/savings.go | 228 ++++++ internal/model/task.go | 89 ++- internal/repository/db.go | 22 + internal/repository/habit.go | 263 ++++++- internal/repository/habit_freeze.go | 151 ++++ internal/repository/savings.go | 1045 +++++++++++++++++++++++++++ internal/repository/task.go | 53 +- internal/scheduler/scheduler.go | 110 ++- internal/service/habit.go | 137 +++- internal/service/task.go | 113 ++- 18 files changed, 2910 insertions(+), 162 deletions(-) create mode 100644 docs/SAVINGS.md create mode 100644 internal/handler/habit_freeze.go create mode 100644 internal/handler/savings.go create mode 100644 internal/model/habit_freeze.go create mode 100644 internal/model/savings.go create mode 100644 internal/repository/habit_freeze.go create mode 100644 internal/repository/savings.go diff --git a/cmd/api/main.go b/cmd/api/main.go index a4cd57f..dcbd487 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -38,11 +38,13 @@ func main() { habitRepo := repository.NewHabitRepository(db) taskRepo := repository.NewTaskRepository(db) emailTokenRepo := repository.NewEmailTokenRepository(db) + habitFreezeRepo := repository.NewHabitFreezeRepository(db) + savingsRepo := repository.NewSavingsRepository(db) // Initialize services emailService := service.NewEmailService(cfg.ResendAPIKey, cfg.FromEmail, cfg.FromName, cfg.AppURL) authService := service.NewAuthService(userRepo, emailTokenRepo, emailService, cfg.JWTSecret) - habitService := service.NewHabitService(habitRepo) + habitService := service.NewHabitService(habitRepo, habitFreezeRepo) taskService := service.NewTaskService(taskRepo) // Initialize Telegram bot @@ -57,7 +59,7 @@ func main() { } // Initialize scheduler - sched := scheduler.New(telegramBot, userRepo, taskRepo, habitRepo) + sched := scheduler.New(telegramBot, userRepo, taskRepo, habitRepo, habitFreezeRepo) sched.Start() defer sched.Stop() @@ -67,6 +69,8 @@ func main() { taskHandler := handler.NewTaskHandler(taskService) healthHandler := handler.NewHealthHandler() profileHandler := handler.NewProfileHandler(userRepo) + habitFreezeHandler := handler.NewHabitFreezeHandler(habitFreezeRepo, habitRepo) + savingsHandler := handler.NewSavingsHandler(savingsRepo) // Initialize middleware authMiddleware := customMiddleware.NewAuthMiddleware(cfg.JWTSecret) @@ -124,6 +128,11 @@ func main() { r.Get("/habits/{id}/logs", habitHandler.GetLogs) r.Delete("/habits/{id}/logs/{logId}", habitHandler.DeleteLog) + // Habit freezes + r.Get("/habits/{id}/freezes", habitFreezeHandler.List) + r.Post("/habits/{id}/freezes", habitFreezeHandler.Create) + r.Delete("/habits/{id}/freezes/{freezeId}", habitFreezeHandler.Delete) + // Stats r.Get("/habits/stats", habitHandler.Stats) r.Get("/habits/{id}/stats", habitHandler.HabitStats) @@ -137,6 +146,34 @@ func main() { r.Delete("/tasks/{id}", taskHandler.Delete) r.Post("/tasks/{id}/complete", taskHandler.Complete) r.Post("/tasks/{id}/uncomplete", taskHandler.Uncomplete) + + // Savings routes + r.Get("/savings/categories", savingsHandler.ListCategories) + r.Post("/savings/categories", savingsHandler.CreateCategory) + r.Get("/savings/categories/{id}", savingsHandler.GetCategory) + r.Put("/savings/categories/{id}", savingsHandler.UpdateCategory) + r.Delete("/savings/categories/{id}", savingsHandler.DeleteCategory) + + // Savings category members + r.Get("/savings/categories/{id}/members", savingsHandler.ListMembers) + r.Post("/savings/categories/{id}/members", savingsHandler.AddMember) + r.Delete("/savings/categories/{id}/members/{userId}", savingsHandler.RemoveMember) + + // Savings recurring plans + r.Get("/savings/categories/{id}/recurring-plans", savingsHandler.ListRecurringPlans) + r.Post("/savings/categories/{id}/recurring-plans", savingsHandler.CreateRecurringPlan) + r.Delete("/savings/recurring-plans/{planId}", savingsHandler.DeleteRecurringPlan) + r.Put("/savings/recurring-plans/{planId}", savingsHandler.UpdateRecurringPlan) + + // Savings transactions + r.Get("/savings/transactions", savingsHandler.ListTransactions) + r.Post("/savings/transactions", savingsHandler.CreateTransaction) + r.Get("/savings/transactions/{id}", savingsHandler.GetTransaction) + r.Put("/savings/transactions/{id}", savingsHandler.UpdateTransaction) + r.Delete("/savings/transactions/{id}", savingsHandler.DeleteTransaction) + + // Savings stats + r.Get("/savings/stats", savingsHandler.Stats) }) port := os.Getenv("PORT") diff --git a/docs/SAVINGS.md b/docs/SAVINGS.md new file mode 100644 index 0000000..5672495 --- /dev/null +++ b/docs/SAVINGS.md @@ -0,0 +1,56 @@ +# Savings Module + +Модуль накоплений для Pulse App. Позволяет отслеживать накопления, вклады, кредиты и регулярные платежи. + +## Функционал + +### Категории +- **Обычные накопления** — копилки с произвольными пополнениями +- **Депозиты (is_deposit)** — вклады с процентной ставкой и сроком +- **Кредиты (is_credit)** — отслеживание погашения кредитов +- **Регулярные (is_recurring)** — категории с ежемесячными обязательствами +- **Мультипользовательские (is_multi)** — общие категории для нескольких пользователей +- **Счета (is_account)** — транзитные счета (исключаются из общих сумм) + +### Recurring Plans +Для категорий с is_recurring=true создаются планы платежей: +- effective — дата начала действия плана +- amount — сумма ежемесячного платежа +- day — день месяца для платежа (1-28) +- user_id — пользователь (для multi-категорий) + +### Алгоритм расчёта просрочек + +1. Определение периода: начало = MAX(category.created_at, earliest_plan.effective) +2. Построение списка месяцев от начала до текущего +3. Расчёт required для каждого месяца по активному плану +4. Allocation deposits: каждый депозит сначала покрывает свой месяц, потом предыдущие +5. Overdues: все прошлые месяцы с remaining > 0 + +### Расчёт сумм +- total_deposits — сумма пополнений пользователя (исключая is_account) +- total_withdrawals — сумма снятий пользователя (исключая is_account) +- total_balance — сумма балансов всех категорий пользователя + +## API Endpoints + +### Categories +- GET /savings/categories — список категорий +- POST /savings/categories — создать категорию +- PUT /savings/categories/:id — обновить +- DELETE /savings/categories/:id — удалить + +### Transactions +- GET /savings/transactions — список транзакций +- POST /savings/transactions — создать +- PUT /savings/transactions/:id — обновить +- DELETE /savings/transactions/:id — удалить + +### Stats +- GET /savings/stats — статистика (балансы, monthly_payments, overdues) + +### Recurring Plans +- GET /savings/categories/:id/recurring-plans — планы +- POST /savings/categories/:id/recurring-plans — создать +- PUT /savings/recurring-plans/:planId — обновить +- DELETE /savings/recurring-plans/:planId — удалить diff --git a/internal/bot/handlers.go b/internal/bot/handlers.go index f89277e..d3cec4c 100644 --- a/internal/bot/handlers.go +++ b/internal/bot/handlers.go @@ -61,6 +61,39 @@ func (b *Bot) handleCallback(callback *tgbotapi.CallbackQuery) { } action := parts[0] + + // Handle checkhabit with optional date (checkhabit_ or checkhabit__yesterday) + if action == "checkhabit" { + id, _ := strconv.ParseInt(parts[1], 10, 64) + + logDate := time.Now() + dateLabel := "сегодня" + if len(parts) >= 3 && parts[2] == "yesterday" { + logDate = time.Now().AddDate(0, 0, -1) + dateLabel = "вчера" + } + + log := &model.HabitLog{ + HabitID: id, + UserID: user.ID, + Date: logDate, + Count: 1, + } + err = b.habitRepo.CreateLog(log) + if err != nil { + if strings.Contains(err.Error(), "duplicate") || strings.Contains(err.Error(), "already") { + b.answerCallback(callback.ID, "⚠️ Уже отмечено за эту дату") + } else { + b.answerCallback(callback.ID, "❌ Ошибка") + } + return + } + + b.answerCallback(callback.ID, fmt.Sprintf("✅ Отмечено за %s!", dateLabel)) + b.refreshHabitsMessage(chatID, messageID, user.ID) + return + } + id, _ := strconv.ParseInt(parts[1], 10, 64) switch action { @@ -81,21 +114,6 @@ func (b *Bot) handleCallback(callback *tgbotapi.CallbackQuery) { } b.answerCallback(callback.ID, "🗑 Задача удалена") b.refreshTasksMessage(chatID, messageID, user.ID) - - case "checkhabit": - log := &model.HabitLog{ - HabitID: id, - UserID: user.ID, - Date: time.Now(), - Count: 1, - } - err = b.habitRepo.CreateLog(log) - if err != nil { - b.answerCallback(callback.ID, "❌ Ошибка") - return - } - b.answerCallback(callback.ID, "✅ Привычка отмечена!") - b.refreshHabitsMessage(chatID, messageID, user.ID) } } @@ -204,11 +222,15 @@ func (b *Bot) buildHabitsMessage(habits []model.Habit, userID int64) (string, *t text := "🎯 Привычки на сегодня:\n\n" var rows [][]tgbotapi.InlineKeyboardButton + + yesterday := time.Now().AddDate(0, 0, -1).Truncate(24 * time.Hour) for _, habit := range todayHabits { - completed, _ := b.habitRepo.IsHabitCompletedToday(habit.ID, userID) + completedToday, _ := b.habitRepo.IsHabitCompletedToday(habit.ID, userID) + completedYesterday, _ := b.habitRepo.IsHabitCompletedOnDate(habit.ID, userID, yesterday) + status := "⬜" - if completed { + if completedToday { status = "✅" } @@ -218,15 +240,31 @@ func (b *Bot) buildHabitsMessage(habits []model.Habit, userID int64) (string, *t } text += "\n" - if !completed { - row := []tgbotapi.InlineKeyboardButton{ - tgbotapi.NewInlineKeyboardButtonData(fmt.Sprintf("✅ %s", habit.Name), fmt.Sprintf("checkhabit_%d", habit.ID)), - } - rows = append(rows, row) + // Build buttons row + var btnRow []tgbotapi.InlineKeyboardButton + + if !completedToday { + btnRow = append(btnRow, tgbotapi.NewInlineKeyboardButtonData( + fmt.Sprintf("✅ %s", habit.Name), + fmt.Sprintf("checkhabit_%d", habit.ID), + )) + } + + // Add "За вчера" button if not completed yesterday + if !completedYesterday { + btnRow = append(btnRow, tgbotapi.NewInlineKeyboardButtonData( + "📅 За вчера", + fmt.Sprintf("checkhabit_%d_yesterday", habit.ID), + )) + } + + if len(btnRow) > 0 { + rows = append(rows, btnRow) } } if len(rows) == 0 { + text += "\n✨ Всё выполнено!" return text, nil } @@ -235,28 +273,13 @@ func (b *Bot) buildHabitsMessage(habits []model.Habit, userID int64) (string, *t } func (b *Bot) handleStart(msg *tgbotapi.Message) { - text := fmt.Sprintf(`👋 Привет! Я бот Pulse. - -Твой Chat ID: %d - -Скопируй его и вставь в настройках Pulse для получения уведомлений. - -Доступные команды: -/tasks — задачи на сегодня -/habits — привычки на сегодня -/help — справка`, msg.Chat.ID) + text := fmt.Sprintf("👋 Привет! Я бот Pulse.\n\nТвой Chat ID: %d\n\nСкопируй его и вставь в настройках Pulse для получения уведомлений.\n\nДоступные команды:\n/tasks — задачи на сегодня\n/habits — привычки на сегодня\n/help — справка", msg.Chat.ID) b.SendMessage(msg.Chat.ID, text) } func (b *Bot) handleHelp(msg *tgbotapi.Message) { - text := `📚 Справка по командам: - -/start — получить твой Chat ID -/tasks — список задач на сегодня -/habits — список привычек - -💡 Чтобы получать уведомления, добавь свой Chat ID в настройках Pulse.` + text := "📚 Справка по командам:\n\n/start — получить твой Chat ID\n/tasks — список задач на сегодня\n/habits — список привычек\n\n💡 Чтобы получать уведомления, добавь свой Chat ID в настройках Pulse." b.SendMessage(msg.Chat.ID, text) } diff --git a/internal/handler/habit_freeze.go b/internal/handler/habit_freeze.go new file mode 100644 index 0000000..da7a8c9 --- /dev/null +++ b/internal/handler/habit_freeze.go @@ -0,0 +1,135 @@ +package handler + +import ( + "encoding/json" + "errors" + "net/http" + "strconv" + "time" + + "github.com/go-chi/chi/v5" + + "github.com/daniil/homelab-api/internal/middleware" + "github.com/daniil/homelab-api/internal/model" + "github.com/daniil/homelab-api/internal/repository" +) + +type HabitFreezeHandler struct { + freezeRepo *repository.HabitFreezeRepository + habitRepo *repository.HabitRepository +} + +func NewHabitFreezeHandler(freezeRepo *repository.HabitFreezeRepository, habitRepo *repository.HabitRepository) *HabitFreezeHandler { + return &HabitFreezeHandler{ + freezeRepo: freezeRepo, + habitRepo: habitRepo, + } +} + +func (h *HabitFreezeHandler) Create(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + habitID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid habit id", http.StatusBadRequest) + return + } + + // Verify habit exists and belongs to user + if _, err := h.habitRepo.GetByID(habitID, userID); err != nil { + if errors.Is(err, repository.ErrHabitNotFound) { + writeError(w, "habit not found", http.StatusNotFound) + return + } + writeError(w, "failed to fetch habit", http.StatusInternalServerError) + return + } + + var req model.CreateHabitFreezeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + + if req.StartDate == "" || req.EndDate == "" { + writeError(w, "start_date and end_date are required", http.StatusBadRequest) + return + } + + startDate, err := time.Parse("2006-01-02", req.StartDate) + if err != nil { + writeError(w, "invalid start_date format", http.StatusBadRequest) + return + } + + endDate, err := time.Parse("2006-01-02", req.EndDate) + if err != nil { + writeError(w, "invalid end_date format", http.StatusBadRequest) + return + } + + freeze := &model.HabitFreeze{ + HabitID: habitID, + UserID: userID, + StartDate: startDate, + EndDate: endDate, + Reason: req.Reason, + } + + if err := h.freezeRepo.Create(freeze); err != nil { + if errors.Is(err, repository.ErrInvalidDateRange) { + writeError(w, "end_date must be after start_date", http.StatusBadRequest) + return + } + writeError(w, "failed to create freeze", http.StatusInternalServerError) + return + } + + writeJSON(w, freeze, http.StatusCreated) +} + +func (h *HabitFreezeHandler) List(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + habitID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid habit id", http.StatusBadRequest) + return + } + + // Verify habit exists and belongs to user + if _, err := h.habitRepo.GetByID(habitID, userID); err != nil { + if errors.Is(err, repository.ErrHabitNotFound) { + writeError(w, "habit not found", http.StatusNotFound) + return + } + writeError(w, "failed to fetch habit", http.StatusInternalServerError) + return + } + + freezes, err := h.freezeRepo.GetByHabitID(habitID, userID) + if err != nil { + writeError(w, "failed to fetch freezes", http.StatusInternalServerError) + return + } + + writeJSON(w, freezes, http.StatusOK) +} + +func (h *HabitFreezeHandler) Delete(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + freezeID, err := strconv.ParseInt(chi.URLParam(r, "freezeId"), 10, 64) + if err != nil { + writeError(w, "invalid freeze id", http.StatusBadRequest) + return + } + + if err := h.freezeRepo.Delete(freezeID, userID); err != nil { + if errors.Is(err, repository.ErrFreezeNotFound) { + writeError(w, "freeze not found", http.StatusNotFound) + return + } + writeError(w, "failed to delete freeze", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/handler/habits.go b/internal/handler/habits.go index df22054..0a0fc96 100644 --- a/internal/handler/habits.go +++ b/internal/handler/habits.go @@ -146,6 +146,14 @@ func (h *HabitHandler) Log(w http.ResponseWriter, r *http.Request) { writeError(w, "habit not found", http.StatusNotFound) return } + if errors.Is(err, service.ErrFutureDate) { + writeError(w, "cannot log habit for future date", http.StatusBadRequest) + return + } + if errors.Is(err, service.ErrAlreadyLogged) { + writeError(w, "habit already logged for this date", http.StatusConflict) + return + } writeError(w, "failed to log habit", http.StatusInternalServerError) return } diff --git a/internal/handler/savings.go b/internal/handler/savings.go new file mode 100644 index 0000000..1f98413 --- /dev/null +++ b/internal/handler/savings.go @@ -0,0 +1,489 @@ +package handler + +import ( + "encoding/json" + "errors" + "net/http" + "strconv" + + "github.com/go-chi/chi/v5" + + "github.com/daniil/homelab-api/internal/middleware" + "github.com/daniil/homelab-api/internal/model" + "github.com/daniil/homelab-api/internal/repository" +) + +type SavingsHandler struct { + repo *repository.SavingsRepository +} + +func NewSavingsHandler(repo *repository.SavingsRepository) *SavingsHandler { + return &SavingsHandler{repo: repo} +} + +// ==================== CATEGORIES ==================== + +func (h *SavingsHandler) ListCategories(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + + categories, err := h.repo.ListCategories(userID) + if err != nil { + writeError(w, "failed to fetch categories", http.StatusInternalServerError) + return + } + + // Process and calculate balances + for i := range categories { + categories[i].ProcessForJSON() + balance, _ := h.repo.GetCategoryBalance(categories[i].ID) + categories[i].CurrentAmount = balance + + if categories[i].IsRecurring { + total, _ := h.repo.GetRecurringTotalAmount(categories[i].ID) + categories[i].RecurringTotalAmount = total + } + } + + writeJSON(w, categories, http.StatusOK) +} + +func (h *SavingsHandler) GetCategory(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + categoryID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid category id", http.StatusBadRequest) + return + } + + category, err := h.repo.GetCategory(categoryID, userID) + if err != nil { + if errors.Is(err, repository.ErrCategoryNotFound) { + writeError(w, "category not found", http.StatusNotFound) + return + } + writeError(w, "failed to fetch category", http.StatusInternalServerError) + return + } + + category.ProcessForJSON() + balance, _ := h.repo.GetCategoryBalance(category.ID) + category.CurrentAmount = balance + + if category.IsRecurring { + total, _ := h.repo.GetRecurringTotalAmount(category.ID) + category.RecurringTotalAmount = total + } + + writeJSON(w, category, http.StatusOK) +} + +func (h *SavingsHandler) CreateCategory(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + + var req model.CreateSavingsCategoryRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + + if req.Name == "" { + writeError(w, "name is required", http.StatusBadRequest) + return + } + + category, err := h.repo.CreateCategory(userID, &req) + if err != nil { + writeError(w, "failed to create category", http.StatusInternalServerError) + return + } + + category.ProcessForJSON() + writeJSON(w, category, http.StatusCreated) +} + +func (h *SavingsHandler) UpdateCategory(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + categoryID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid category id", http.StatusBadRequest) + return + } + + var req model.UpdateSavingsCategoryRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + + category, err := h.repo.UpdateCategory(categoryID, userID, &req) + if err != nil { + if errors.Is(err, repository.ErrCategoryNotFound) { + writeError(w, "category not found", http.StatusNotFound) + return + } + if errors.Is(err, repository.ErrNotAuthorized) { + writeError(w, "not authorized", http.StatusForbidden) + return + } + writeError(w, "failed to update category", http.StatusInternalServerError) + return + } + + category.ProcessForJSON() + balance, _ := h.repo.GetCategoryBalance(category.ID) + category.CurrentAmount = balance + + writeJSON(w, category, http.StatusOK) +} + +func (h *SavingsHandler) DeleteCategory(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + categoryID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid category id", http.StatusBadRequest) + return + } + + if err := h.repo.DeleteCategory(categoryID, userID); err != nil { + if errors.Is(err, repository.ErrCategoryNotFound) { + writeError(w, "category not found", http.StatusNotFound) + return + } + writeError(w, "failed to delete category", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// ==================== TRANSACTIONS ==================== + +func (h *SavingsHandler) ListTransactions(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + + var categoryID *int64 + if catIDStr := r.URL.Query().Get("category_id"); catIDStr != "" { + id, err := strconv.ParseInt(catIDStr, 10, 64) + if err == nil { + categoryID = &id + } + } + + limit := 100 + if l := r.URL.Query().Get("limit"); l != "" { + if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 { + limit = parsed + } + } + + offset := 0 + if o := r.URL.Query().Get("offset"); o != "" { + if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 { + offset = parsed + } + } + + transactions, err := h.repo.ListTransactions(userID, categoryID, limit, offset) + if err != nil { + writeError(w, "failed to fetch transactions", http.StatusInternalServerError) + return + } + + writeJSON(w, transactions, http.StatusOK) +} + +func (h *SavingsHandler) GetTransaction(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + txID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid transaction id", http.StatusBadRequest) + return + } + + transaction, err := h.repo.GetTransaction(txID, userID) + if err != nil { + if errors.Is(err, repository.ErrTransactionNotFound) { + writeError(w, "transaction not found", http.StatusNotFound) + return + } + writeError(w, "failed to fetch transaction", http.StatusInternalServerError) + return + } + + writeJSON(w, transaction, http.StatusOK) +} + +func (h *SavingsHandler) CreateTransaction(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + + var req model.CreateSavingsTransactionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + + if req.CategoryID == 0 { + writeError(w, "category_id is required", http.StatusBadRequest) + return + } + if req.Amount <= 0 { + writeError(w, "amount must be positive", http.StatusBadRequest) + return + } + if req.Type != "deposit" && req.Type != "withdrawal" { + writeError(w, "type must be 'deposit' or 'withdrawal'", http.StatusBadRequest) + return + } + if req.Date == "" { + writeError(w, "date is required", http.StatusBadRequest) + return + } + + transaction, err := h.repo.CreateTransaction(userID, &req) + if err != nil { + if errors.Is(err, repository.ErrCategoryNotFound) { + writeError(w, "category not found", http.StatusNotFound) + return + } + writeError(w, "failed to create transaction: "+err.Error(), http.StatusInternalServerError) + return + } + + writeJSON(w, transaction, http.StatusCreated) +} + +func (h *SavingsHandler) UpdateTransaction(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + txID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid transaction id", http.StatusBadRequest) + return + } + + var req model.UpdateSavingsTransactionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + + transaction, err := h.repo.UpdateTransaction(txID, userID, &req) + if err != nil { + if errors.Is(err, repository.ErrTransactionNotFound) { + writeError(w, "transaction not found", http.StatusNotFound) + return + } + if errors.Is(err, repository.ErrNotAuthorized) { + writeError(w, "not authorized", http.StatusForbidden) + return + } + writeError(w, "failed to update transaction", http.StatusInternalServerError) + return + } + + writeJSON(w, transaction, http.StatusOK) +} + +func (h *SavingsHandler) DeleteTransaction(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + txID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid transaction id", http.StatusBadRequest) + return + } + + if err := h.repo.DeleteTransaction(txID, userID); err != nil { + if errors.Is(err, repository.ErrTransactionNotFound) { + writeError(w, "transaction not found", http.StatusNotFound) + return + } + writeError(w, "failed to delete transaction", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// ==================== STATS ==================== + +func (h *SavingsHandler) Stats(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + + stats, err := h.repo.GetStats(userID) + if err != nil { + writeError(w, "failed to fetch stats", http.StatusInternalServerError) + return + } + + writeJSON(w, stats, http.StatusOK) +} + +// ==================== MEMBERS ==================== + +func (h *SavingsHandler) ListMembers(w http.ResponseWriter, r *http.Request) { + categoryID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid category id", http.StatusBadRequest) + return + } + + members, err := h.repo.GetCategoryMembers(categoryID) + if err != nil { + writeError(w, "failed to fetch members", http.StatusInternalServerError) + return + } + + writeJSON(w, members, http.StatusOK) +} + +func (h *SavingsHandler) AddMember(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + categoryID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid category id", http.StatusBadRequest) + return + } + + // Check ownership + category, err := h.repo.GetCategory(categoryID, userID) + if err != nil { + writeError(w, "category not found", http.StatusNotFound) + return + } + if category.UserID != userID { + writeError(w, "not authorized", http.StatusForbidden) + return + } + + var req struct { + UserID int64 `json:"user_id"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + + if err := h.repo.AddCategoryMember(categoryID, req.UserID); err != nil { + writeError(w, "failed to add member", http.StatusInternalServerError) + return + } + + members, _ := h.repo.GetCategoryMembers(categoryID) + writeJSON(w, members, http.StatusOK) +} + +func (h *SavingsHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + categoryID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid category id", http.StatusBadRequest) + return + } + memberUserID, err := strconv.ParseInt(chi.URLParam(r, "userId"), 10, 64) + if err != nil { + writeError(w, "invalid user id", http.StatusBadRequest) + return + } + + // Check ownership + category, err := h.repo.GetCategory(categoryID, userID) + if err != nil { + writeError(w, "category not found", http.StatusNotFound) + return + } + if category.UserID != userID { + writeError(w, "not authorized", http.StatusForbidden) + return + } + + if err := h.repo.RemoveCategoryMember(categoryID, memberUserID); err != nil { + writeError(w, "failed to remove member", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// ==================== RECURRING PLANS ==================== + +func (h *SavingsHandler) ListRecurringPlans(w http.ResponseWriter, r *http.Request) { + categoryID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid category id", http.StatusBadRequest) + return + } + + plans, err := h.repo.ListRecurringPlans(categoryID) + if err != nil { + writeError(w, "failed to fetch recurring plans", http.StatusInternalServerError) + return + } + + writeJSON(w, plans, http.StatusOK) +} + +func (h *SavingsHandler) CreateRecurringPlan(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r.Context()) + categoryID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeError(w, "invalid category id", http.StatusBadRequest) + return + } + + // Check access + _, err = h.repo.GetCategory(categoryID, userID) + if err != nil { + writeError(w, "category not found", http.StatusNotFound) + return + } + + var req model.CreateRecurringPlanRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + + plan, err := h.repo.CreateRecurringPlan(categoryID, &req) + if err != nil { + writeError(w, "failed to create recurring plan: "+err.Error(), http.StatusInternalServerError) + return + } + + writeJSON(w, plan, http.StatusCreated) +} + +func (h *SavingsHandler) DeleteRecurringPlan(w http.ResponseWriter, r *http.Request) { + planID, err := strconv.ParseInt(chi.URLParam(r, "planId"), 10, 64) + if err != nil { + writeError(w, "invalid plan id", http.StatusBadRequest) + return + } + + if err := h.repo.DeleteRecurringPlan(planID); err != nil { + writeError(w, "failed to delete recurring plan", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (h *SavingsHandler) UpdateRecurringPlan(w http.ResponseWriter, r *http.Request) { + planID, err := strconv.ParseInt(chi.URLParam(r, "planId"), 10, 64) + if err != nil { + writeError(w, "invalid plan id", http.StatusBadRequest) + return + } + + var req model.UpdateRecurringPlanRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, "invalid request body", http.StatusBadRequest) + return + } + + plan, err := h.repo.UpdateRecurringPlan(planID, &req) + if err != nil { + writeError(w, "failed to update recurring plan: "+err.Error(), http.StatusInternalServerError) + return + } + + writeJSON(w, plan, http.StatusOK) +} diff --git a/internal/model/habit.go b/internal/model/habit.go index 43cfc5f..f2f8c54 100644 --- a/internal/model/habit.go +++ b/internal/model/habit.go @@ -18,6 +18,8 @@ type Habit struct { TargetCount int `db:"target_count" json:"target_count"` ReminderTime sql.NullString `db:"reminder_time" json:"-"` ReminderTimeStr *string `db:"-" json:"reminder_time"` + StartDate sql.NullTime `db:"start_date" json:"-"` + StartDateStr *string `db:"-" json:"start_date"` IsArchived bool `db:"is_archived" json:"is_archived"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` @@ -42,6 +44,12 @@ func (h *Habit) ProcessForJSON() { h.ReminderTimeStr = &formatted } } + + // Process start_date + if h.StartDate.Valid { + formatted := h.StartDate.Time.Format("2006-01-02") + h.StartDateStr = &formatted + } } type HabitLog struct { @@ -63,6 +71,7 @@ type CreateHabitRequest struct { TargetDays []int `json:"target_days,omitempty"` TargetCount int `json:"target_count,omitempty"` ReminderTime *string `json:"reminder_time,omitempty"` + StartDate *string `json:"start_date,omitempty"` } type UpdateHabitRequest struct { @@ -74,6 +83,7 @@ type UpdateHabitRequest struct { TargetDays []int `json:"target_days,omitempty"` TargetCount *int `json:"target_count,omitempty"` ReminderTime *string `json:"reminder_time,omitempty"` + StartDate *string `json:"start_date,omitempty"` IsArchived *bool `json:"is_archived,omitempty"` } diff --git a/internal/model/habit_freeze.go b/internal/model/habit_freeze.go new file mode 100644 index 0000000..de51fb9 --- /dev/null +++ b/internal/model/habit_freeze.go @@ -0,0 +1,21 @@ +package model + +import ( + "time" +) + +type HabitFreeze struct { + ID int64 `db:"id" json:"id"` + HabitID int64 `db:"habit_id" json:"habit_id"` + UserID int64 `db:"user_id" json:"user_id"` + StartDate time.Time `db:"start_date" json:"start_date"` + EndDate time.Time `db:"end_date" json:"end_date"` + Reason string `db:"reason" json:"reason"` + CreatedAt time.Time `db:"created_at" json:"created_at"` +} + +type CreateHabitFreezeRequest struct { + StartDate string `json:"start_date"` + EndDate string `json:"end_date"` + Reason string `json:"reason,omitempty"` +} diff --git a/internal/model/savings.go b/internal/model/savings.go new file mode 100644 index 0000000..54a3afa --- /dev/null +++ b/internal/model/savings.go @@ -0,0 +1,228 @@ +package model + +import ( + "database/sql" + "time" +) + +type SavingsCategory struct { + ID int64 `db:"id" json:"id"` + UserID int64 `db:"user_id" json:"user_id"` + Name string `db:"name" json:"name"` + Description string `db:"description" json:"description"` + + // Type flags + IsDeposit bool `db:"is_deposit" json:"is_deposit"` + IsCredit bool `db:"is_credit" json:"is_credit"` + IsAccount bool `db:"is_account" json:"is_account"` + IsRecurring bool `db:"is_recurring" json:"is_recurring"` + IsMulti bool `db:"is_multi" json:"is_multi"` + IsClosed bool `db:"is_closed" json:"is_closed"` + + // Initial capital + InitialCapital float64 `db:"initial_capital" json:"initial_capital"` + + // Deposit fields + DepositAmount float64 `db:"deposit_amount" json:"deposit_amount"` + InterestRate float64 `db:"interest_rate" json:"interest_rate"` + DepositStartDate sql.NullTime `db:"deposit_start_date" json:"-"` + DepositStartStr *string `db:"-" json:"deposit_start_date"` + DepositTerm int `db:"deposit_term" json:"deposit_term"` + DepositEndDate sql.NullTime `db:"deposit_end_date" json:"-"` + DepositEndStr *string `db:"-" json:"deposit_end_date"` + LastInterestCalc sql.NullTime `db:"last_interest_calc" json:"-"` + FinalAmount float64 `db:"final_amount" json:"final_amount"` + + // Credit fields + CreditAmount float64 `db:"credit_amount" json:"credit_amount"` + CreditTerm int `db:"credit_term" json:"credit_term"` + CreditRate float64 `db:"credit_rate" json:"credit_rate"` + CreditStartDate sql.NullTime `db:"credit_start_date" json:"-"` + CreditStartStr *string `db:"-" json:"credit_start_date"` + + // Recurring fields + RecurringAmount float64 `db:"recurring_amount" json:"recurring_amount"` + RecurringDay int `db:"recurring_day" json:"recurring_day"` + RecurringStartDate sql.NullTime `db:"recurring_start_date" json:"-"` + LastRecurringRun sql.NullTime `db:"last_recurring_run" json:"-"` + + // Computed (populated in service) + CurrentAmount float64 `db:"-" json:"current_amount"` + RecurringTotalAmount float64 `db:"-" json:"recurring_total_amount"` + + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + + // Relations + Members []SavingsCategoryMember `db:"-" json:"members,omitempty"` +} + +func (c *SavingsCategory) ProcessForJSON() { + if c.DepositStartDate.Valid { + formatted := c.DepositStartDate.Time.Format("2006-01-02") + c.DepositStartStr = &formatted + } + if c.DepositEndDate.Valid { + formatted := c.DepositEndDate.Time.Format("2006-01-02") + c.DepositEndStr = &formatted + } + if c.CreditStartDate.Valid { + formatted := c.CreditStartDate.Time.Format("2006-01-02") + c.CreditStartStr = &formatted + } +} + +type SavingsTransaction struct { + ID int64 `db:"id" json:"id"` + CategoryID int64 `db:"category_id" json:"category_id"` + UserID int64 `db:"user_id" json:"user_id"` + Amount float64 `db:"amount" json:"amount"` + Type string `db:"type" json:"type"` // deposit, withdrawal + Description string `db:"description" json:"description"` + Date time.Time `db:"date" json:"date"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + CategoryName string `db:"category_name" json:"category_name,omitempty"` + UserName string `db:"user_name" json:"user_name,omitempty"` +} + +type SavingsRecurringPlan struct { + ID int64 `db:"id" json:"id"` + CategoryID int64 `db:"category_id" json:"category_id"` + UserID sql.NullInt64 `db:"user_id" json:"-"` + UserIDPtr *int64 `db:"-" json:"user_id"` + Effective time.Time `db:"effective" json:"effective"` + Amount float64 `db:"amount" json:"amount"` + Day int `db:"day" json:"day"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +func (p *SavingsRecurringPlan) ProcessForJSON() { + if p.UserID.Valid { + p.UserIDPtr = &p.UserID.Int64 + } +} + +type SavingsCategoryMember struct { + ID int64 `db:"id" json:"id"` + CategoryID int64 `db:"category_id" json:"category_id"` + UserID int64 `db:"user_id" json:"user_id"` + UserName string `db:"user_name" json:"user_name,omitempty"` + CreatedAt time.Time `db:"created_at" json:"created_at"` +} + +// Request DTOs +type CreateSavingsCategoryRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + IsDeposit bool `json:"is_deposit,omitempty"` + IsCredit bool `json:"is_credit,omitempty"` + IsAccount bool `json:"is_account,omitempty"` + IsRecurring bool `json:"is_recurring,omitempty"` + IsMulti bool `json:"is_multi,omitempty"` + InitialCapital float64 `json:"initial_capital,omitempty"` + DepositAmount float64 `json:"deposit_amount,omitempty"` + InterestRate float64 `json:"interest_rate,omitempty"` + DepositStartDate *string `json:"deposit_start_date,omitempty"` + DepositTerm int `json:"deposit_term,omitempty"` + CreditAmount float64 `json:"credit_amount,omitempty"` + CreditTerm int `json:"credit_term,omitempty"` + CreditRate float64 `json:"credit_rate,omitempty"` + CreditStartDate *string `json:"credit_start_date,omitempty"` + RecurringAmount float64 `json:"recurring_amount,omitempty"` + RecurringDay int `json:"recurring_day,omitempty"` + RecurringStartDate *string `json:"recurring_start_date,omitempty"` + MemberIDs []int64 `json:"member_ids,omitempty"` +} + +type UpdateSavingsCategoryRequest struct { + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + IsDeposit *bool `json:"is_deposit,omitempty"` + IsCredit *bool `json:"is_credit,omitempty"` + IsAccount *bool `json:"is_account,omitempty"` + IsRecurring *bool `json:"is_recurring,omitempty"` + IsMulti *bool `json:"is_multi,omitempty"` + IsClosed *bool `json:"is_closed,omitempty"` + InitialCapital *float64 `json:"initial_capital,omitempty"` + DepositAmount *float64 `json:"deposit_amount,omitempty"` + InterestRate *float64 `json:"interest_rate,omitempty"` + DepositStartDate *string `json:"deposit_start_date,omitempty"` + DepositTerm *int `json:"deposit_term,omitempty"` + FinalAmount *float64 `json:"final_amount,omitempty"` + CreditAmount *float64 `json:"credit_amount,omitempty"` + CreditTerm *int `json:"credit_term,omitempty"` + CreditRate *float64 `json:"credit_rate,omitempty"` + CreditStartDate *string `json:"credit_start_date,omitempty"` + RecurringAmount *float64 `json:"recurring_amount,omitempty"` + RecurringDay *int `json:"recurring_day,omitempty"` + RecurringStartDate *string `json:"recurring_start_date,omitempty"` +} + +type CreateSavingsTransactionRequest struct { + CategoryID int64 `json:"category_id"` + Amount float64 `json:"amount"` + Type string `json:"type"` // deposit, withdrawal + Description string `json:"description,omitempty"` + Date string `json:"date"` +} + +type UpdateSavingsTransactionRequest struct { + Amount *float64 `json:"amount,omitempty"` + Type *string `json:"type,omitempty"` + Description *string `json:"description,omitempty"` + Date *string `json:"date,omitempty"` +} + +type CreateRecurringPlanRequest struct { + Effective string `json:"effective"` + Amount float64 `json:"amount"` + Day int `json:"day,omitempty"` + UserID *int64 `json:"user_id,omitempty"` +} + +type UpdateRecurringPlanRequest struct { + Effective *string `json:"effective,omitempty"` + Amount *float64 `json:"amount,omitempty"` + Day *int `json:"day,omitempty"` +} + +type SavingsStats struct { + MonthlyPayments float64 `json:"monthly_payments"` + MonthlyPaymentDetails []MonthlyPaymentDetail `json:"monthly_payment_details"` + Overdues []OverduePayment `json:"overdues"` + TotalBalance float64 `json:"total_balance"` + TotalDeposits float64 `json:"total_deposits"` + TotalWithdrawals float64 `json:"total_withdrawals"` + CategoriesCount int `json:"categories_count"` + ByCategory []CategoryStats `json:"by_category"` +} + +type CategoryStats struct { + CategoryID int64 `json:"category_id"` + CategoryName string `json:"category_name"` + Balance float64 `json:"balance"` + IsDeposit bool `json:"is_deposit"` + IsRecurring bool `json:"is_recurring"` +} + +// MonthlyPaymentDetail represents a recurring payment detail +type MonthlyPaymentDetail struct { + CategoryID int64 `json:"category_id"` + CategoryName string `json:"category_name"` + Amount float64 `json:"amount"` + Day int `json:"day"` +} + +// OverduePayment represents an overdue recurring payment +type OverduePayment struct { + CategoryID int64 `json:"category_id"` + CategoryName string `json:"category_name"` + UserID int64 `json:"user_id"` + UserName string `json:"user_name"` + Amount float64 `json:"amount"` + DueDay int `json:"due_day"` + DaysOverdue int `json:"days_overdue"` + Month string `json:"month"` +} diff --git a/internal/model/task.go b/internal/model/task.go index acae010..ef81496 100644 --- a/internal/model/task.go +++ b/internal/model/task.go @@ -7,21 +7,30 @@ import ( ) type Task struct { - ID int64 `db:"id" json:"id"` - UserID int64 `db:"user_id" json:"user_id"` - Title string `db:"title" json:"title"` - Description string `db:"description" json:"description"` - Icon string `db:"icon" json:"icon"` - Color string `db:"color" json:"color"` - DueDate sql.NullTime `db:"due_date" json:"-"` - DueDateStr *string `db:"-" json:"due_date"` - Priority int `db:"priority" json:"priority"` - ReminderTime sql.NullString `db:"reminder_time" json:"-"` - ReminderTimeStr *string `db:"-" json:"reminder_time"` - CompletedAt sql.NullTime `db:"completed_at" json:"-"` - Completed bool `db:"-" json:"completed"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ID int64 `db:"id" json:"id"` + UserID int64 `db:"user_id" json:"user_id"` + Title string `db:"title" json:"title"` + Description string `db:"description" json:"description"` + Icon string `db:"icon" json:"icon"` + Color string `db:"color" json:"color"` + DueDate sql.NullTime `db:"due_date" json:"-"` + DueDateStr *string `db:"-" json:"due_date"` + Priority int `db:"priority" json:"priority"` + ReminderTime sql.NullString `db:"reminder_time" json:"-"` + ReminderTimeStr *string `db:"-" json:"reminder_time"` + CompletedAt sql.NullTime `db:"completed_at" json:"-"` + Completed bool `db:"-" json:"completed"` + // Recurring task fields + IsRecurring bool `db:"is_recurring" json:"is_recurring"` + RecurrenceType sql.NullString `db:"recurrence_type" json:"-"` + RecurrenceTypeStr *string `db:"-" json:"recurrence_type"` + RecurrenceInterval int `db:"recurrence_interval" json:"recurrence_interval"` + RecurrenceEndDate sql.NullTime `db:"recurrence_end_date" json:"-"` + RecurrenceEndStr *string `db:"-" json:"recurrence_end_date"` + ParentTaskID sql.NullInt64 `db:"parent_task_id" json:"-"` + ParentTaskIDPtr *int64 `db:"-" json:"parent_task_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } func (t *Task) ProcessForJSON() { @@ -47,24 +56,46 @@ func (t *Task) ProcessForJSON() { } } t.Completed = t.CompletedAt.Valid + + // Process recurring fields + if t.RecurrenceType.Valid { + t.RecurrenceTypeStr = &t.RecurrenceType.String + } + if t.RecurrenceEndDate.Valid { + s := t.RecurrenceEndDate.Time.Format("2006-01-02") + t.RecurrenceEndStr = &s + } + if t.ParentTaskID.Valid { + t.ParentTaskIDPtr = &t.ParentTaskID.Int64 + } } type CreateTaskRequest struct { - Title string `json:"title"` - Description string `json:"description,omitempty"` - Icon string `json:"icon,omitempty"` - Color string `json:"color,omitempty"` - DueDate *string `json:"due_date,omitempty"` - Priority int `json:"priority,omitempty"` - ReminderTime *string `json:"reminder_time,omitempty"` + Title string `json:"title"` + Description string `json:"description,omitempty"` + Icon string `json:"icon,omitempty"` + Color string `json:"color,omitempty"` + DueDate *string `json:"due_date,omitempty"` + Priority int `json:"priority,omitempty"` + ReminderTime *string `json:"reminder_time,omitempty"` + // Recurring fields + IsRecurring bool `json:"is_recurring,omitempty"` + RecurrenceType *string `json:"recurrence_type,omitempty"` + RecurrenceInterval int `json:"recurrence_interval,omitempty"` + RecurrenceEndDate *string `json:"recurrence_end_date,omitempty"` } type UpdateTaskRequest struct { - Title *string `json:"title,omitempty"` - Description *string `json:"description,omitempty"` - Icon *string `json:"icon,omitempty"` - Color *string `json:"color,omitempty"` - DueDate *string `json:"due_date,omitempty"` - Priority *int `json:"priority,omitempty"` - ReminderTime *string `json:"reminder_time,omitempty"` + Title *string `json:"title,omitempty"` + Description *string `json:"description,omitempty"` + Icon *string `json:"icon,omitempty"` + Color *string `json:"color,omitempty"` + DueDate *string `json:"due_date,omitempty"` + Priority *int `json:"priority,omitempty"` + ReminderTime *string `json:"reminder_time,omitempty"` + // Recurring fields + IsRecurring *bool `json:"is_recurring,omitempty"` + RecurrenceType *string `json:"recurrence_type,omitempty"` + RecurrenceInterval *int `json:"recurrence_interval,omitempty"` + RecurrenceEndDate *string `json:"recurrence_end_date,omitempty"` } diff --git a/internal/repository/db.go b/internal/repository/db.go index e6ac18a..311baa0 100644 --- a/internal/repository/db.go +++ b/internal/repository/db.go @@ -95,6 +95,28 @@ func RunMigrations(db *sqlx.DB) error { `ALTER TABLE tasks ADD COLUMN IF NOT EXISTS reminder_time TIME`, `ALTER TABLE habits ADD COLUMN IF NOT EXISTS reminder_time TIME`, `CREATE INDEX IF NOT EXISTS idx_users_telegram_chat_id ON users(telegram_chat_id)`, + // Recurring tasks support + `ALTER TABLE tasks ADD COLUMN IF NOT EXISTS is_recurring BOOLEAN DEFAULT false`, + `ALTER TABLE tasks ADD COLUMN IF NOT EXISTS recurrence_type VARCHAR(20)`, + `ALTER TABLE tasks ADD COLUMN IF NOT EXISTS recurrence_interval INTEGER DEFAULT 1`, + `ALTER TABLE tasks ADD COLUMN IF NOT EXISTS recurrence_end_date DATE`, + `ALTER TABLE tasks ADD COLUMN IF NOT EXISTS parent_task_id INTEGER REFERENCES tasks(id) ON DELETE SET NULL`, + `CREATE INDEX IF NOT EXISTS idx_tasks_parent_id ON tasks(parent_task_id)`, + `CREATE INDEX IF NOT EXISTS idx_tasks_recurring ON tasks(is_recurring) WHERE is_recurring = true`, + // Habit freezes support + `CREATE TABLE IF NOT EXISTS habit_freezes ( + id SERIAL PRIMARY KEY, + habit_id INTEGER REFERENCES habits(id) ON DELETE CASCADE, + user_id INTEGER REFERENCES users(id) ON DELETE CASCADE, + start_date DATE NOT NULL, + end_date DATE NOT NULL, + reason VARCHAR(255) DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + )`, + `CREATE INDEX IF NOT EXISTS idx_habit_freezes_habit ON habit_freezes(habit_id)`, + `CREATE INDEX IF NOT EXISTS idx_habit_freezes_dates ON habit_freezes(start_date, end_date)`, + // Habit start_date support + `ALTER TABLE habits ADD COLUMN IF NOT EXISTS start_date DATE`, } for _, migration := range migrations { diff --git a/internal/repository/habit.go b/internal/repository/habit.go index a995989..e2f514e 100644 --- a/internal/repository/habit.go +++ b/internal/repository/habit.go @@ -23,8 +23,8 @@ func NewHabitRepository(db *sqlx.DB) *HabitRepository { func (r *HabitRepository) Create(habit *model.Habit) error { query := ` - INSERT INTO habits (user_id, name, description, color, icon, frequency, target_days, target_count, reminder_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + INSERT INTO habits (user_id, name, description, color, icon, frequency, target_days, target_count, reminder_time, start_date) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at` targetDays := pq.Array(habit.TargetDays) @@ -42,6 +42,7 @@ func (r *HabitRepository) Create(habit *model.Habit) error { targetDays, habit.TargetCount, habit.ReminderTime, + habit.StartDate, ).Scan(&habit.ID, &habit.CreatedAt, &habit.UpdatedAt) } @@ -50,13 +51,13 @@ func (r *HabitRepository) GetByID(id, userID int64) (*model.Habit, error) { var targetDays pq.Int64Array query := ` - SELECT id, user_id, name, description, color, icon, frequency, target_days, target_count, reminder_time, is_archived, created_at, updated_at + SELECT id, user_id, name, description, color, icon, frequency, target_days, target_count, reminder_time, start_date, is_archived, created_at, updated_at FROM habits WHERE id = $1 AND user_id = $2` err := r.db.QueryRow(query, id, userID).Scan( &habit.ID, &habit.UserID, &habit.Name, &habit.Description, &habit.Color, &habit.Icon, &habit.Frequency, &targetDays, - &habit.TargetCount, &habit.ReminderTime, &habit.IsArchived, &habit.CreatedAt, &habit.UpdatedAt, + &habit.TargetCount, &habit.ReminderTime, &habit.StartDate, &habit.IsArchived, &habit.CreatedAt, &habit.UpdatedAt, ) if err != nil { @@ -77,7 +78,7 @@ func (r *HabitRepository) GetByID(id, userID int64) (*model.Habit, error) { func (r *HabitRepository) ListByUser(userID int64, includeArchived bool) ([]model.Habit, error) { query := ` - SELECT id, user_id, name, description, color, icon, frequency, target_days, target_count, reminder_time, is_archived, created_at, updated_at + SELECT id, user_id, name, description, color, icon, frequency, target_days, target_count, reminder_time, start_date, is_archived, created_at, updated_at FROM habits WHERE user_id = $1` if !includeArchived { @@ -99,7 +100,7 @@ func (r *HabitRepository) ListByUser(userID int64, includeArchived bool) ([]mode if err := rows.Scan( &habit.ID, &habit.UserID, &habit.Name, &habit.Description, &habit.Color, &habit.Icon, &habit.Frequency, &targetDays, - &habit.TargetCount, &habit.ReminderTime, &habit.IsArchived, &habit.CreatedAt, &habit.UpdatedAt, + &habit.TargetCount, &habit.ReminderTime, &habit.StartDate, &habit.IsArchived, &habit.CreatedAt, &habit.UpdatedAt, ); err != nil { return nil, err } @@ -118,14 +119,19 @@ func (r *HabitRepository) ListByUser(userID int64, includeArchived bool) ([]mode func (r *HabitRepository) GetHabitsWithReminder(reminderTime string, weekday int) ([]model.Habit, error) { query := ` - SELECT h.id, h.user_id, h.name, h.description, h.color, h.icon, h.frequency, h.target_days, h.target_count, h.reminder_time, h.is_archived, h.created_at, h.updated_at + SELECT h.id, h.user_id, h.name, h.description, h.color, h.icon, h.frequency, h.target_days, h.target_count, h.reminder_time, h.start_date, h.is_archived, h.created_at, h.updated_at FROM habits h JOIN users u ON h.user_id = u.id WHERE h.reminder_time = $1 AND h.is_archived = false - AND (h.frequency = 'daily' OR $2 = ANY(h.target_days)) AND u.telegram_chat_id IS NOT NULL - AND u.notifications_enabled = true` + AND u.notifications_enabled = true + AND ( + h.frequency = 'daily' + OR (h.frequency = 'weekly' AND $2 = ANY(h.target_days)) + OR h.frequency = 'interval' + OR h.frequency = 'custom' + )` rows, err := r.db.Query(query, reminderTime, weekday) if err != nil { @@ -141,7 +147,7 @@ func (r *HabitRepository) GetHabitsWithReminder(reminderTime string, weekday int if err := rows.Scan( &habit.ID, &habit.UserID, &habit.Name, &habit.Description, &habit.Color, &habit.Icon, &habit.Frequency, &targetDays, - &habit.TargetCount, &habit.ReminderTime, &habit.IsArchived, &habit.CreatedAt, &habit.UpdatedAt, + &habit.TargetCount, &habit.ReminderTime, &habit.StartDate, &habit.IsArchived, &habit.CreatedAt, &habit.UpdatedAt, ); err != nil { return nil, err } @@ -158,12 +164,41 @@ func (r *HabitRepository) GetHabitsWithReminder(reminderTime string, weekday int return habits, nil } +// ShouldShowIntervalHabitToday checks if an interval habit should be shown today +func (r *HabitRepository) ShouldShowIntervalHabitToday(habitID, userID int64, intervalDays int, startDate sql.NullTime) (bool, error) { + // Get the last log date for this habit + var lastLogDate sql.NullTime + err := r.db.Get(&lastLogDate, ` + SELECT MAX(date) FROM habit_logs WHERE habit_id = $1 AND user_id = $2 + `, habitID, userID) + + if err != nil && err != sql.ErrNoRows { + return false, err + } + + today := time.Now().Truncate(24 * time.Hour) + + // If no logs exist, check if today >= start_date (show on start_date) + if !lastLogDate.Valid { + if startDate.Valid { + return !today.Before(startDate.Time.Truncate(24*time.Hour)), nil + } + return true, nil + } + + // Calculate days since last log + lastLog := lastLogDate.Time.Truncate(24 * time.Hour) + daysSinceLastLog := int(today.Sub(lastLog).Hours() / 24) + + return daysSinceLastLog >= intervalDays, nil +} + func (r *HabitRepository) Update(habit *model.Habit) error { query := ` UPDATE habits SET name = $2, description = $3, color = $4, icon = $5, frequency = $6, - target_days = $7, target_count = $8, reminder_time = $9, is_archived = $10, updated_at = CURRENT_TIMESTAMP - WHERE id = $1 AND user_id = $11 + target_days = $7, target_count = $8, reminder_time = $9, start_date = $10, is_archived = $11, updated_at = CURRENT_TIMESTAMP + WHERE id = $1 AND user_id = $12 RETURNING updated_at` return r.db.QueryRow(query, @@ -176,6 +211,7 @@ func (r *HabitRepository) Update(habit *model.Habit) error { pq.Array(habit.TargetDays), habit.TargetCount, habit.ReminderTime, + habit.StartDate, habit.IsArchived, habit.UserID, ).Scan(&habit.UpdatedAt) @@ -264,40 +300,115 @@ func (r *HabitRepository) IsHabitCompletedToday(habitID, userID int64) (bool, er func (r *HabitRepository) GetStats(habitID, userID int64) (*model.HabitStats, error) { stats := &model.HabitStats{HabitID: habitID} + // Get habit info + habit, err := r.GetByID(habitID, userID) + if err != nil { + return nil, err + } + // Total logs r.db.Get(&stats.TotalLogs, `SELECT COUNT(*) FROM habit_logs WHERE habit_id = $1 AND user_id = $2`, habitID, userID) - // This week - weekStart := time.Now().AddDate(0, 0, -int(time.Now().Weekday())) + // This week (Monday-based) + now := time.Now() + weekday := int(now.Weekday()) + if weekday == 0 { + weekday = 7 + } + weekStart := now.AddDate(0, 0, -(weekday - 1)).Truncate(24 * time.Hour) r.db.Get(&stats.ThisWeek, `SELECT COUNT(*) FROM habit_logs WHERE habit_id = $1 AND user_id = $2 AND date >= $3`, habitID, userID, weekStart) // This month - monthStart := time.Date(time.Now().Year(), time.Now().Month(), 1, 0, 0, 0, 0, time.UTC) + monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) r.db.Get(&stats.ThisMonth, `SELECT COUNT(*) FROM habit_logs WHERE habit_id = $1 AND user_id = $2 AND date >= $3`, habitID, userID, monthStart) - // Streaks calculation - stats.CurrentStreak, stats.LongestStreak = r.calculateStreaks(habitID, userID) + // Streaks calculation (respecting target_days and interval) + stats.CurrentStreak, stats.LongestStreak = r.calculateStreaksWithDays(habitID, userID, habit.Frequency, habit.TargetDays, habit.TargetCount) + + // Completion percentage since habit creation/start_date + stats.CompletionPct = r.calculateCompletionPct(habit, stats.TotalLogs) return stats, nil } -func (r *HabitRepository) calculateStreaks(habitID, userID int64) (current, longest int) { +// calculateStreaksWithDays counts consecutive completions on expected days +func (r *HabitRepository) calculateStreaksWithDays(habitID, userID int64, frequency string, targetDays []int, targetCount int) (current, longest int) { query := `SELECT date FROM habit_logs WHERE habit_id = $1 AND user_id = $2 ORDER BY date DESC` - var dates []time.Time - if err := r.db.Select(&dates, query, habitID, userID); err != nil || len(dates) == 0 { + var logDates []time.Time + if err := r.db.Select(&logDates, query, habitID, userID); err != nil || len(logDates) == 0 { + return 0, 0 + } + + // Convert log dates to map for quick lookup + logMap := make(map[string]bool) + for _, d := range logDates { + logMap[d.Format("2006-01-02")] = true + } + + // For interval habits, calculate streaks differently + if (frequency == "interval" || frequency == "custom") && targetCount > 0 { + return r.calculateIntervalStreaks(logDates, targetCount) + } + + // Generate expected days from today backwards + today := time.Now().Truncate(24 * time.Hour) + expectedDays := r.getExpectedDays(today, frequency, targetDays, 365) // Look back up to 1 year + + if len(expectedDays) == 0 { + return 0, 0 + } + + // Current streak: count from most recent expected day + current = 0 + for _, day := range expectedDays { + if logMap[day.Format("2006-01-02")] { + current++ + } else { + break + } + } + + // Longest streak + longest = 0 + streak := 0 + for _, day := range expectedDays { + if logMap[day.Format("2006-01-02")] { + streak++ + if streak > longest { + longest = streak + } + } else { + streak = 0 + } + } + + return current, longest +} + +// calculateIntervalStreaks calculates streaks for interval-based habits +func (r *HabitRepository) calculateIntervalStreaks(logDates []time.Time, intervalDays int) (current, longest int) { + if len(logDates) == 0 { return 0, 0 } today := time.Now().Truncate(24 * time.Hour) - yesterday := today.AddDate(0, 0, -1) - - // Current streak - if dates[0].Truncate(24*time.Hour).Equal(today) || dates[0].Truncate(24*time.Hour).Equal(yesterday) { + + // Check if the most recent log is within the interval window from today + lastLogDate := logDates[0].Truncate(24 * time.Hour) + daysSinceLastLog := int(today.Sub(lastLogDate).Hours() / 24) + + // Current streak: if we're within interval, count consecutive logs that are within interval of each other + current = 0 + if daysSinceLastLog < intervalDays { current = 1 - for i := 1; i < len(dates); i++ { - expected := dates[i-1].AddDate(0, 0, -1).Truncate(24 * time.Hour) - if dates[i].Truncate(24 * time.Hour).Equal(expected) { + for i := 1; i < len(logDates); i++ { + prevDate := logDates[i-1].Truncate(24 * time.Hour) + currDate := logDates[i].Truncate(24 * time.Hour) + daysBetween := int(prevDate.Sub(currDate).Hours() / 24) + + // If the gap is exactly the interval (or less, if done early), continue streak + if daysBetween <= intervalDays { current++ } else { break @@ -305,12 +416,15 @@ func (r *HabitRepository) calculateStreaks(habitID, userID int64) (current, long } } - // Longest streak - streak := 1 + // Longest streak calculation longest = 1 - for i := 1; i < len(dates); i++ { - expected := dates[i-1].AddDate(0, 0, -1).Truncate(24 * time.Hour) - if dates[i].Truncate(24 * time.Hour).Equal(expected) { + streak := 1 + for i := 1; i < len(logDates); i++ { + prevDate := logDates[i-1].Truncate(24 * time.Hour) + currDate := logDates[i].Truncate(24 * time.Hour) + daysBetween := int(prevDate.Sub(currDate).Hours() / 24) + + if daysBetween <= intervalDays { streak++ if streak > longest { longest = streak @@ -322,3 +436,88 @@ func (r *HabitRepository) calculateStreaks(habitID, userID int64) (current, long return current, longest } + +// getExpectedDays returns a list of days when the habit should be done, sorted descending +func (r *HabitRepository) getExpectedDays(from time.Time, frequency string, targetDays []int, maxDays int) []time.Time { + var result []time.Time + + for i := 0; i < maxDays; i++ { + day := from.AddDate(0, 0, -i) + + if frequency == "daily" { + result = append(result, day) + } else if frequency == "weekly" && len(targetDays) > 0 { + weekday := int(day.Weekday()) + if weekday == 0 { + weekday = 7 // Sunday = 7 + } + for _, td := range targetDays { + if td == weekday { + result = append(result, day) + break + } + } + } else { + result = append(result, day) + } + } + + return result +} + +// calculateCompletionPct calculates completion percentage since habit start_date (or created_at) +func (r *HabitRepository) calculateCompletionPct(habit *model.Habit, totalLogs int) float64 { + if totalLogs == 0 { + return 0 + } + + // Use start_date if set, otherwise use created_at + var startDate time.Time + if habit.StartDate.Valid { + startDate = habit.StartDate.Time.Truncate(24 * time.Hour) + } else { + startDate = habit.CreatedAt.Truncate(24 * time.Hour) + } + today := time.Now().Truncate(24 * time.Hour) + + expectedCount := 0 + + // For interval habits, calculate expected differently + if (habit.Frequency == "interval" || habit.Frequency == "custom") && habit.TargetCount > 0 { + // Expected = (days since start) / interval + 1 + daysSinceStart := int(today.Sub(startDate).Hours()/24) + 1 + expectedCount = (daysSinceStart / habit.TargetCount) + 1 + } else { + for d := startDate; !d.After(today); d = d.AddDate(0, 0, 1) { + if habit.Frequency == "daily" { + expectedCount++ + } else if habit.Frequency == "weekly" && len(habit.TargetDays) > 0 { + weekday := int(d.Weekday()) + if weekday == 0 { + weekday = 7 + } + for _, td := range habit.TargetDays { + if td == weekday { + expectedCount++ + break + } + } + } else { + expectedCount++ + } + } + } + + if expectedCount == 0 { + return 0 + } + + return float64(totalLogs) / float64(expectedCount) * 100 +} + +func (r *HabitRepository) IsHabitCompletedOnDate(habitID, userID int64, date time.Time) (bool, error) { + dateStr := date.Format("2006-01-02") + var count int + err := r.db.Get(&count, `SELECT COUNT(*) FROM habit_logs WHERE habit_id = $1 AND user_id = $2 AND date = $3`, habitID, userID, dateStr) + return count > 0, err +} diff --git a/internal/repository/habit_freeze.go b/internal/repository/habit_freeze.go new file mode 100644 index 0000000..8a15e41 --- /dev/null +++ b/internal/repository/habit_freeze.go @@ -0,0 +1,151 @@ +package repository + +import ( + "errors" + "time" + + "github.com/daniil/homelab-api/internal/model" + "github.com/jmoiron/sqlx" +) + +var ErrFreezeNotFound = errors.New("freeze not found") +var ErrInvalidDateRange = errors.New("invalid date range") + +type HabitFreezeRepository struct { + db *sqlx.DB +} + +func NewHabitFreezeRepository(db *sqlx.DB) *HabitFreezeRepository { + return &HabitFreezeRepository{db: db} +} + +func (r *HabitFreezeRepository) Create(freeze *model.HabitFreeze) error { + // Validate date range + if freeze.EndDate.Before(freeze.StartDate) { + return ErrInvalidDateRange + } + + query := ` + INSERT INTO habit_freezes (habit_id, user_id, start_date, end_date, reason) + VALUES ($1, $2, $3, $4, $5) + RETURNING id, created_at` + + return r.db.QueryRow(query, + freeze.HabitID, + freeze.UserID, + freeze.StartDate, + freeze.EndDate, + freeze.Reason, + ).Scan(&freeze.ID, &freeze.CreatedAt) +} + +func (r *HabitFreezeRepository) GetByHabitID(habitID, userID int64) ([]model.HabitFreeze, error) { + query := ` + SELECT id, habit_id, user_id, start_date, end_date, reason, created_at + FROM habit_freezes + WHERE habit_id = $1 AND user_id = $2 + ORDER BY start_date DESC` + + var freezes []model.HabitFreeze + if err := r.db.Select(&freezes, query, habitID, userID); err != nil { + return nil, err + } + + if freezes == nil { + freezes = []model.HabitFreeze{} + } + + return freezes, nil +} + +func (r *HabitFreezeRepository) GetActiveForHabit(habitID int64, date time.Time) (*model.HabitFreeze, error) { + query := ` + SELECT id, habit_id, user_id, start_date, end_date, reason, created_at + FROM habit_freezes + WHERE habit_id = $1 AND start_date <= $2 AND end_date >= $2` + + var freeze model.HabitFreeze + err := r.db.Get(&freeze, query, habitID, date) + if err != nil { + return nil, err + } + + return &freeze, nil +} + +func (r *HabitFreezeRepository) IsHabitFrozenOnDate(habitID int64, date time.Time) (bool, error) { + query := ` + SELECT COUNT(*) FROM habit_freezes + WHERE habit_id = $1 AND start_date <= $2 AND end_date >= $2` + + var count int + err := r.db.Get(&count, query, habitID, date) + if err != nil { + return false, err + } + + return count > 0, nil +} + +func (r *HabitFreezeRepository) GetFreezesForDateRange(habitID int64, startDate, endDate time.Time) ([]model.HabitFreeze, error) { + query := ` + SELECT id, habit_id, user_id, start_date, end_date, reason, created_at + FROM habit_freezes + WHERE habit_id = $1 + AND NOT (end_date < $2 OR start_date > $3) + ORDER BY start_date` + + var freezes []model.HabitFreeze + if err := r.db.Select(&freezes, query, habitID, startDate, endDate); err != nil { + return nil, err + } + + if freezes == nil { + freezes = []model.HabitFreeze{} + } + + return freezes, nil +} + +func (r *HabitFreezeRepository) Delete(freezeID, userID int64) error { + query := `DELETE FROM habit_freezes WHERE id = $1 AND user_id = $2` + result, err := r.db.Exec(query, freezeID, userID) + if err != nil { + return err + } + + rows, _ := result.RowsAffected() + if rows == 0 { + return ErrFreezeNotFound + } + + return nil +} + +func (r *HabitFreezeRepository) CountFrozenDaysInRange(habitID int64, startDate, endDate time.Time) (int, error) { + freezes, err := r.GetFreezesForDateRange(habitID, startDate, endDate) + if err != nil { + return 0, err + } + + frozenDays := 0 + for _, freeze := range freezes { + // Calculate overlap between freeze period and query range + overlapStart := freeze.StartDate + if startDate.After(freeze.StartDate) { + overlapStart = startDate + } + + overlapEnd := freeze.EndDate + if endDate.Before(freeze.EndDate) { + overlapEnd = endDate + } + + if !overlapEnd.Before(overlapStart) { + days := int(overlapEnd.Sub(overlapStart).Hours()/24) + 1 + frozenDays += days + } + } + + return frozenDays, nil +} diff --git a/internal/repository/savings.go b/internal/repository/savings.go new file mode 100644 index 0000000..1195ba5 --- /dev/null +++ b/internal/repository/savings.go @@ -0,0 +1,1045 @@ +package repository + +import ( + "database/sql" + "errors" + "fmt" + "time" + + "github.com/jmoiron/sqlx" + + "github.com/daniil/homelab-api/internal/model" +) + +var ( + ErrCategoryNotFound = errors.New("category not found") + ErrTransactionNotFound = errors.New("transaction not found") + ErrNotAuthorized = errors.New("not authorized") +) + +type SavingsRepository struct { + db *sqlx.DB +} + +func NewSavingsRepository(db *sqlx.DB) *SavingsRepository { + return &SavingsRepository{db: db} +} + +// ==================== CATEGORIES ==================== + +func (r *SavingsRepository) ListCategories(userID int64) ([]model.SavingsCategory, error) { + query := ` + SELECT DISTINCT c.* FROM savings_categories c + LEFT JOIN savings_category_members m ON c.id = m.category_id + WHERE c.user_id = $1 OR m.user_id = $1 + ORDER BY c.created_at DESC + ` + + var categories []model.SavingsCategory + err := r.db.Select(&categories, query, userID) + if err != nil { + return nil, err + } + + // Load members for each category + for i := range categories { + members, err := r.GetCategoryMembers(categories[i].ID) + if err == nil { + categories[i].Members = members + } + } + + return categories, nil +} + +func (r *SavingsRepository) GetCategory(id, userID int64) (*model.SavingsCategory, error) { + query := ` + SELECT DISTINCT c.* FROM savings_categories c + LEFT JOIN savings_category_members m ON c.id = m.category_id + WHERE c.id = $1 AND (c.user_id = $2 OR m.user_id = $2) + ` + + var category model.SavingsCategory + err := r.db.Get(&category, query, id, userID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrCategoryNotFound + } + return nil, err + } + + // Load members + members, err := r.GetCategoryMembers(id) + if err == nil { + category.Members = members + } + + return &category, nil +} + +func (r *SavingsRepository) CreateCategory(userID int64, req *model.CreateSavingsCategoryRequest) (*model.SavingsCategory, error) { + query := ` + INSERT INTO savings_categories ( + user_id, name, description, is_deposit, is_credit, is_account, is_recurring, is_multi, + initial_capital, deposit_amount, interest_rate, deposit_start_date, deposit_term, + credit_amount, credit_term, credit_rate, credit_start_date, + recurring_amount, recurring_day, recurring_start_date + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20 + ) RETURNING id, created_at, updated_at + ` + + var depositStartDate, creditStartDate, recurringStartDate *time.Time + if req.DepositStartDate != nil { + t, _ := time.Parse("2006-01-02", *req.DepositStartDate) + depositStartDate = &t + } + if req.CreditStartDate != nil { + t, _ := time.Parse("2006-01-02", *req.CreditStartDate) + creditStartDate = &t + } + if req.RecurringStartDate != nil { + t, _ := time.Parse("2006-01-02", *req.RecurringStartDate) + recurringStartDate = &t + } + + category := &model.SavingsCategory{ + UserID: userID, + Name: req.Name, + Description: req.Description, + IsDeposit: req.IsDeposit, + IsCredit: req.IsCredit, + IsAccount: req.IsAccount, + IsRecurring: req.IsRecurring, + IsMulti: req.IsMulti, + InitialCapital: req.InitialCapital, + DepositAmount: req.DepositAmount, + InterestRate: req.InterestRate, + DepositTerm: req.DepositTerm, + CreditAmount: req.CreditAmount, + CreditTerm: req.CreditTerm, + CreditRate: req.CreditRate, + RecurringAmount: req.RecurringAmount, + RecurringDay: req.RecurringDay, + } + + err := r.db.QueryRow( + query, + userID, req.Name, req.Description, req.IsDeposit, req.IsCredit, req.IsAccount, req.IsRecurring, req.IsMulti, + req.InitialCapital, req.DepositAmount, req.InterestRate, depositStartDate, req.DepositTerm, + req.CreditAmount, req.CreditTerm, req.CreditRate, creditStartDate, + req.RecurringAmount, req.RecurringDay, recurringStartDate, + ).Scan(&category.ID, &category.CreatedAt, &category.UpdatedAt) + + if err != nil { + return nil, err + } + + // Add creator as member if multi + if req.IsMulti { + r.AddCategoryMember(category.ID, userID) + for _, memberID := range req.MemberIDs { + if memberID != userID { + r.AddCategoryMember(category.ID, memberID) + } + } + } + + return r.GetCategory(category.ID, userID) +} + +func (r *SavingsRepository) UpdateCategory(id, userID int64, req *model.UpdateSavingsCategoryRequest) (*model.SavingsCategory, error) { + // Check ownership + var ownerID int64 + err := r.db.Get(&ownerID, "SELECT user_id FROM savings_categories WHERE id = $1", id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrCategoryNotFound + } + return nil, err + } + if ownerID != userID { + return nil, ErrNotAuthorized + } + + // Build dynamic update query + updates := []string{} + args := []interface{}{} + argNum := 1 + + if req.Name != nil { + updates = append(updates, fmt.Sprintf("name = $%d", argNum)) + args = append(args, *req.Name) + argNum++ + } + if req.Description != nil { + updates = append(updates, fmt.Sprintf("description = $%d", argNum)) + args = append(args, *req.Description) + argNum++ + } + if req.IsDeposit != nil { + updates = append(updates, fmt.Sprintf("is_deposit = $%d", argNum)) + args = append(args, *req.IsDeposit) + argNum++ + } + if req.IsCredit != nil { + updates = append(updates, fmt.Sprintf("is_credit = $%d", argNum)) + args = append(args, *req.IsCredit) + argNum++ + } + if req.IsAccount != nil { + updates = append(updates, fmt.Sprintf("is_account = $%d", argNum)) + args = append(args, *req.IsAccount) + argNum++ + } + if req.IsRecurring != nil { + updates = append(updates, fmt.Sprintf("is_recurring = $%d", argNum)) + args = append(args, *req.IsRecurring) + argNum++ + } + if req.IsMulti != nil { + updates = append(updates, fmt.Sprintf("is_multi = $%d", argNum)) + args = append(args, *req.IsMulti) + argNum++ + } + if req.IsClosed != nil { + updates = append(updates, fmt.Sprintf("is_closed = $%d", argNum)) + args = append(args, *req.IsClosed) + argNum++ + } + if req.InitialCapital != nil { + updates = append(updates, fmt.Sprintf("initial_capital = $%d", argNum)) + args = append(args, *req.InitialCapital) + argNum++ + } + if req.DepositAmount != nil { + updates = append(updates, fmt.Sprintf("deposit_amount = $%d", argNum)) + args = append(args, *req.DepositAmount) + argNum++ + } + if req.InterestRate != nil { + updates = append(updates, fmt.Sprintf("interest_rate = $%d", argNum)) + args = append(args, *req.InterestRate) + argNum++ + } + if req.DepositStartDate != nil { + t, _ := time.Parse("2006-01-02", *req.DepositStartDate) + updates = append(updates, fmt.Sprintf("deposit_start_date = $%d", argNum)) + args = append(args, t) + argNum++ + } + if req.DepositTerm != nil { + updates = append(updates, fmt.Sprintf("deposit_term = $%d", argNum)) + args = append(args, *req.DepositTerm) + argNum++ + } + if req.FinalAmount != nil { + updates = append(updates, fmt.Sprintf("final_amount = $%d", argNum)) + args = append(args, *req.FinalAmount) + argNum++ + } + if req.RecurringAmount != nil { + updates = append(updates, fmt.Sprintf("recurring_amount = $%d", argNum)) + args = append(args, *req.RecurringAmount) + argNum++ + } + if req.RecurringDay != nil { + updates = append(updates, fmt.Sprintf("recurring_day = $%d", argNum)) + args = append(args, *req.RecurringDay) + argNum++ + } + + if len(updates) == 0 { + return r.GetCategory(id, userID) + } + + updates = append(updates, "updated_at = NOW()") + args = append(args, id) + + query := fmt.Sprintf("UPDATE savings_categories SET %s WHERE id = $%d", + joinStrings(updates, ", "), argNum) + + _, err = r.db.Exec(query, args...) + if err != nil { + return nil, err + } + + return r.GetCategory(id, userID) +} + +func joinStrings(s []string, sep string) string { + if len(s) == 0 { + return "" + } + result := s[0] + for i := 1; i < len(s); i++ { + result += sep + s[i] + } + return result +} + +func (r *SavingsRepository) DeleteCategory(id, userID int64) error { + result, err := r.db.Exec("DELETE FROM savings_categories WHERE id = $1 AND user_id = $2", id, userID) + if err != nil { + return err + } + + rows, _ := result.RowsAffected() + if rows == 0 { + return ErrCategoryNotFound + } + + return nil +} + +// ==================== TRANSACTIONS ==================== + +func (r *SavingsRepository) ListTransactions(userID int64, categoryID *int64, limit, offset int) ([]model.SavingsTransaction, error) { + query := ` + SELECT DISTINCT ON (t.id) t.*, c.name as category_name, u.username as user_name + FROM savings_transactions t + JOIN savings_categories c ON t.category_id = c.id + JOIN users u ON t.user_id = u.id + LEFT JOIN savings_category_members m ON c.id = m.category_id + WHERE (c.user_id = $1 OR m.user_id = $1) + ` + args := []interface{}{userID} + argNum := 2 + + if categoryID != nil { + query += fmt.Sprintf(" AND t.category_id = $%d", argNum) + args = append(args, *categoryID) + argNum++ + } + + query += " ORDER BY t.id, t.date DESC" + + // Wrap for proper ordering + query = fmt.Sprintf("SELECT * FROM (%s) sub ORDER BY date DESC, id DESC", query) + + if limit > 0 { + query += fmt.Sprintf(" LIMIT $%d", argNum) + args = append(args, limit) + argNum++ + } + if offset > 0 { + query += fmt.Sprintf(" OFFSET $%d", argNum) + args = append(args, offset) + } + + var transactions []model.SavingsTransaction + err := r.db.Select(&transactions, query, args...) + return transactions, err +} + +func (r *SavingsRepository) GetTransaction(id, userID int64) (*model.SavingsTransaction, error) { + query := ` + SELECT t.*, c.name as category_name, u.username as user_name + FROM savings_transactions t + JOIN savings_categories c ON t.category_id = c.id + JOIN users u ON t.user_id = u.id + LEFT JOIN savings_category_members m ON c.id = m.category_id + WHERE t.id = $1 AND (c.user_id = $2 OR m.user_id = $2 OR t.user_id = $2) + ` + + var transaction model.SavingsTransaction + err := r.db.Get(&transaction, query, id, userID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrTransactionNotFound + } + return nil, err + } + return &transaction, nil +} + +func (r *SavingsRepository) CreateTransaction(userID int64, req *model.CreateSavingsTransactionRequest) (*model.SavingsTransaction, error) { + // Verify user has access to category + query := ` + SELECT 1 FROM savings_categories c + LEFT JOIN savings_category_members m ON c.id = m.category_id + WHERE c.id = $1 AND (c.user_id = $2 OR m.user_id = $2) + ` + var exists int + err := r.db.Get(&exists, query, req.CategoryID, userID) + if err != nil { + return nil, ErrCategoryNotFound + } + + date, err := time.Parse("2006-01-02", req.Date) + if err != nil { + return nil, fmt.Errorf("invalid date format") + } + + insertQuery := ` + INSERT INTO savings_transactions (category_id, user_id, amount, type, description, date) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id, created_at, updated_at + ` + + tx := &model.SavingsTransaction{ + CategoryID: req.CategoryID, + UserID: userID, + Amount: req.Amount, + Type: req.Type, + Description: req.Description, + Date: date, + } + + err = r.db.QueryRow(insertQuery, req.CategoryID, userID, req.Amount, req.Type, req.Description, date). + Scan(&tx.ID, &tx.CreatedAt, &tx.UpdatedAt) + if err != nil { + return nil, err + } + + return r.GetTransaction(tx.ID, userID) +} + +func (r *SavingsRepository) UpdateTransaction(id, userID int64, req *model.UpdateSavingsTransactionRequest) (*model.SavingsTransaction, error) { + // Verify ownership + var txUserID int64 + err := r.db.Get(&txUserID, "SELECT user_id FROM savings_transactions WHERE id = $1", id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrTransactionNotFound + } + return nil, err + } + if txUserID != userID { + return nil, ErrNotAuthorized + } + + updates := []string{} + args := []interface{}{} + argNum := 1 + + if req.Amount != nil { + updates = append(updates, fmt.Sprintf("amount = $%d", argNum)) + args = append(args, *req.Amount) + argNum++ + } + if req.Type != nil { + updates = append(updates, fmt.Sprintf("type = $%d", argNum)) + args = append(args, *req.Type) + argNum++ + } + if req.Description != nil { + updates = append(updates, fmt.Sprintf("description = $%d", argNum)) + args = append(args, *req.Description) + argNum++ + } + if req.Date != nil { + date, _ := time.Parse("2006-01-02", *req.Date) + updates = append(updates, fmt.Sprintf("date = $%d", argNum)) + args = append(args, date) + argNum++ + } + + if len(updates) == 0 { + return r.GetTransaction(id, userID) + } + + updates = append(updates, "updated_at = NOW()") + args = append(args, id) + + query := fmt.Sprintf("UPDATE savings_transactions SET %s WHERE id = $%d", + joinStrings(updates, ", "), argNum) + + _, err = r.db.Exec(query, args...) + if err != nil { + return nil, err + } + + return r.GetTransaction(id, userID) +} + +func (r *SavingsRepository) DeleteTransaction(id, userID int64) error { + result, err := r.db.Exec("DELETE FROM savings_transactions WHERE id = $1 AND user_id = $2", id, userID) + if err != nil { + return err + } + + rows, _ := result.RowsAffected() + if rows == 0 { + return ErrTransactionNotFound + } + + return nil +} + +// ==================== CATEGORY MEMBERS ==================== + +func (r *SavingsRepository) GetCategoryMembers(categoryID int64) ([]model.SavingsCategoryMember, error) { + query := ` + SELECT m.*, u.username as user_name + FROM savings_category_members m + JOIN users u ON m.user_id = u.id + WHERE m.category_id = $1 + ` + var members []model.SavingsCategoryMember + err := r.db.Select(&members, query, categoryID) + return members, err +} + +func (r *SavingsRepository) AddCategoryMember(categoryID, userID int64) error { + _, err := r.db.Exec(` + INSERT INTO savings_category_members (category_id, user_id) + VALUES ($1, $2) + ON CONFLICT (category_id, user_id) DO NOTHING + `, categoryID, userID) + return err +} + +func (r *SavingsRepository) RemoveCategoryMember(categoryID, userID int64) error { + _, err := r.db.Exec("DELETE FROM savings_category_members WHERE category_id = $1 AND user_id = $2", categoryID, userID) + return err +} + +// ==================== RECURRING PLANS ==================== + +func (r *SavingsRepository) ListRecurringPlans(categoryID int64) ([]model.SavingsRecurringPlan, error) { + query := "SELECT * FROM savings_recurring_plans WHERE category_id = $1 ORDER BY effective DESC" + var plans []model.SavingsRecurringPlan + err := r.db.Select(&plans, query, categoryID) + if err != nil { + return nil, err + } + for i := range plans { + plans[i].ProcessForJSON() + } + return plans, nil +} + +func (r *SavingsRepository) CreateRecurringPlan(categoryID int64, req *model.CreateRecurringPlanRequest) (*model.SavingsRecurringPlan, error) { + effective, err := time.Parse("2006-01-02", req.Effective) + if err != nil { + return nil, fmt.Errorf("invalid effective date") + } + + day := req.Day + if day == 0 { + day = 1 + } + + query := ` + INSERT INTO savings_recurring_plans (category_id, user_id, effective, amount, day) + VALUES ($1, $2, $3, $4, $5) + RETURNING id, created_at, updated_at + ` + + plan := &model.SavingsRecurringPlan{ + CategoryID: categoryID, + Effective: effective, + Amount: req.Amount, + Day: day, + } + + err = r.db.QueryRow(query, categoryID, req.UserID, effective, req.Amount, day). + Scan(&plan.ID, &plan.CreatedAt, &plan.UpdatedAt) + if err != nil { + return nil, err + } + + if req.UserID != nil { + plan.UserID = sql.NullInt64{Int64: *req.UserID, Valid: true} + } + plan.ProcessForJSON() + + return plan, nil +} + +func (r *SavingsRepository) DeleteRecurringPlan(id int64) error { + _, err := r.db.Exec("DELETE FROM savings_recurring_plans WHERE id = $1", id) + return err +} + +func (r *SavingsRepository) UpdateRecurringPlan(id int64, req *model.UpdateRecurringPlanRequest) (*model.SavingsRecurringPlan, error) { + updates := []string{} + args := []interface{}{} + argNum := 1 + + if req.Effective != nil { + effective, err := time.Parse("2006-01-02", *req.Effective) + if err != nil { + return nil, fmt.Errorf("invalid effective date") + } + updates = append(updates, fmt.Sprintf("effective = $%d", argNum)) + args = append(args, effective) + argNum++ + } + if req.Amount != nil { + updates = append(updates, fmt.Sprintf("amount = $%d", argNum)) + args = append(args, *req.Amount) + argNum++ + } + if req.Day != nil { + updates = append(updates, fmt.Sprintf("day = $%d", argNum)) + args = append(args, *req.Day) + argNum++ + } + + if len(updates) == 0 { + return r.GetRecurringPlan(id) + } + + updates = append(updates, "updated_at = NOW()") + args = append(args, id) + + query := fmt.Sprintf("UPDATE savings_recurring_plans SET %s WHERE id = $%d", + joinStrings(updates, ", "), argNum) + + _, err := r.db.Exec(query, args...) + if err != nil { + return nil, err + } + + return r.GetRecurringPlan(id) +} + +func (r *SavingsRepository) GetRecurringPlan(id int64) (*model.SavingsRecurringPlan, error) { + var plan model.SavingsRecurringPlan + err := r.db.Get(&plan, "SELECT * FROM savings_recurring_plans WHERE id = $1", id) + if err != nil { + return nil, err + } + plan.ProcessForJSON() + return &plan, nil +} + +// ==================== STATS ==================== + +func (r *SavingsRepository) GetCategoryBalance(categoryID int64) (float64, error) { + var balance float64 + err := r.db.Get(&balance, ` + SELECT COALESCE( + (SELECT initial_capital FROM savings_categories WHERE id = $1), 0 + ) + COALESCE( + (SELECT SUM(CASE WHEN type = 'deposit' THEN amount ELSE -amount END) + FROM savings_transactions WHERE category_id = $1), 0 + ) + `, categoryID) + return balance, err +} + +func (r *SavingsRepository) GetStats(userID int64) (*model.SavingsStats, error) { + stats := &model.SavingsStats{ + ByCategory: []model.CategoryStats{}, + } + + // Get categories for user + categories, err := r.ListCategories(userID) + if err != nil { + return nil, err + } + + stats.CategoriesCount = len(categories) + + for _, cat := range categories { + balance, _ := r.GetCategoryBalance(cat.ID) + + catStats := model.CategoryStats{ + CategoryID: cat.ID, + CategoryName: cat.Name, + Balance: balance, + IsDeposit: cat.IsDeposit, + IsRecurring: cat.IsRecurring, + } + stats.ByCategory = append(stats.ByCategory, catStats) + stats.TotalBalance += balance + } + + // Total deposits and withdrawals + err = r.db.Get(&stats.TotalDeposits, ` + SELECT COALESCE(SUM(t.amount), 0) + FROM savings_transactions t + JOIN savings_categories c ON t.category_id = c.id + WHERE t.user_id = $1 AND t.type = 'deposit' AND c.is_account = false + `, userID) + if err != nil { + return nil, err + } + + err = r.db.Get(&stats.TotalWithdrawals, ` + SELECT COALESCE(SUM(t.amount), 0) + FROM savings_transactions t + JOIN savings_categories c ON t.category_id = c.id + WHERE t.user_id = $1 AND t.type = 'withdrawal' AND c.is_account = false + `, userID) + if err != nil { + return nil, err + } + + // Get monthly payments (only for current user, only unpaid) + monthly, details, _ := r.GetCurrentMonthlyPayments(userID) + stats.MonthlyPayments = monthly + stats.MonthlyPaymentDetails = details + + // Get overdues (all past months) + overdues, _ := r.GetOverdues(userID) + stats.Overdues = overdues + + return stats, nil +} + +// GetRecurringTotalAmount calculates the total recurring target for a category +func (r *SavingsRepository) GetRecurringTotalAmount(categoryID int64) (float64, error) { + plans, err := r.ListRecurringPlans(categoryID) + if err != nil { + return 0, err + } + + if len(plans) == 0 { + return 0, nil + } + + var total float64 + now := time.Now() + + for _, plan := range plans { + monthsActive := 0 + checkDate := plan.Effective + + for i, nextPlan := range plans { + if i == 0 { + continue + } + endDate := nextPlan.Effective + if i == len(plans)-1 { + endDate = now + } + for checkDate.Before(endDate) { + monthsActive++ + checkDate = checkDate.AddDate(0, 1, 0) + } + total += float64(monthsActive) * plan.Amount + monthsActive = 0 + } + } + + if len(plans) > 0 { + lastPlan := plans[len(plans)-1] + checkDate := lastPlan.Effective + monthsActive := 0 + for checkDate.Before(now) { + monthsActive++ + checkDate = checkDate.AddDate(0, 1, 0) + } + total += float64(monthsActive) * lastPlan.Amount + } + + return total, nil +} + +// GetCurrentMonthlyPayments returns pending recurring payments for current user this month +// GetCurrentMonthlyPayments returns pending recurring payments for current user this month +func (r *SavingsRepository) GetCurrentMonthlyPayments(userID int64) (float64, []model.MonthlyPaymentDetail, error) { + now := time.Now() + currentMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) + + categories, err := r.ListCategories(userID) + if err != nil { + return 0, nil, err + } + + var totalMonthly float64 + var details []model.MonthlyPaymentDetail + + for _, cat := range categories { + if !cat.IsRecurring || cat.IsClosed { + continue + } + + // Get plans for current user + var plans []struct { + Effective time.Time `db:"effective"` + Amount float64 `db:"amount"` + Day int `db:"day"` + } + err := r.db.Select(&plans, ` + SELECT effective, amount, day + FROM savings_recurring_plans + WHERE category_id = $1 AND user_id = $2 + ORDER BY effective ASC + `, cat.ID, userID) + if err != nil || len(plans) == 0 { + continue + } + + // Determine start month (same logic as old Savings) + // Start from category.CreatedAt month + start := time.Date(cat.CreatedAt.Year(), cat.CreatedAt.Month(), 1, 0, 0, 0, 0, time.UTC) + + // For multi categories, use earliest plan effective if earlier + if cat.IsMulti && len(plans) > 0 { + earliestEffective := time.Date(plans[0].Effective.Year(), plans[0].Effective.Month(), 1, 0, 0, 0, 0, time.UTC) + if earliestEffective.Before(start) { + start = earliestEffective + } + } + + // Build months list from start to current + var months []time.Time + for m := start; !m.After(currentMonth); m = m.AddDate(0, 1, 0) { + months = append(months, m) + } + if len(months) == 0 { + continue + } + + // Calculate required amounts per month from plans + required := make([]float64, len(months)) + days := make([]int, len(months)) + for i, ms := range months { + for _, p := range plans { + planMonth := time.Date(p.Effective.Year(), p.Effective.Month(), 1, 0, 0, 0, 0, time.UTC) + if !planMonth.After(ms) { + required[i] = p.Amount + days[i] = p.Day + if days[i] < 1 || days[i] > 28 { + days[i] = 1 + } + } + } + } + + // Remaining amounts + remaining := make([]float64, len(required)) + copy(remaining, required) + + // Get deposits for this user and category since start + var deposits []struct { + Date time.Time `db:"date"` + Amount float64 `db:"amount"` + } + r.db.Select(&deposits, ` + SELECT date, amount FROM savings_transactions + WHERE category_id = $1 AND user_id = $2 AND type = 'deposit' AND date >= $3 + ORDER BY date ASC + `, cat.ID, userID, start) + + // Helper: find month index for date + findIdx := func(d time.Time) int { + ds := time.Date(d.Year(), d.Month(), 1, 0, 0, 0, 0, time.UTC) + for i, ms := range months { + if ms.Equal(ds) { + return i + } + } + if ds.Before(months[0]) { + return -1 + } + return len(months) - 1 + } + + // Allocate deposits: first to current month, then to older months (newest first) + for _, dep := range deposits { + amt := dep.Amount + idx := findIdx(dep.Date) + if idx == -1 { + continue + } + + // Pay current month first + if remaining[idx] > 0 { + use := amt + if use > remaining[idx] { + use = remaining[idx] + } + remaining[idx] -= use + amt -= use + } + + // Then pay older months (newest to oldest) + for k := idx - 1; k >= 0 && amt > 0; k-- { + if remaining[k] <= 0 { + continue + } + use := amt + if use > remaining[k] { + use = remaining[k] + } + remaining[k] -= use + amt -= use + } + } + + // Check current month (last in list) + last := len(months) - 1 + if days[last] == 0 { + days[last] = 1 + } + dueDate := time.Date(months[last].Year(), months[last].Month(), days[last], 0, 0, 0, 0, time.UTC) + + // Only show if due date passed and not fully paid + if !dueDate.After(now) && remaining[last] > 0.01 { + details = append(details, model.MonthlyPaymentDetail{ + CategoryID: cat.ID, + CategoryName: cat.Name, + Amount: remaining[last], + Day: days[last], + }) + totalMonthly += remaining[last] + } + } + + return totalMonthly, details, nil +} + +// GetOverdues returns overdue payments using the same allocation algorithm as old Savings +func (r *SavingsRepository) GetOverdues(userID int64) ([]model.OverduePayment, error) { + now := time.Now() + currentMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) + + categories, err := r.ListCategories(userID) + if err != nil { + return nil, err + } + + var overdues []model.OverduePayment + + for _, cat := range categories { + if !cat.IsRecurring || cat.IsClosed { + continue + } + + // Get plans for current user only + var plans []struct { + Effective time.Time `db:"effective"` + Amount float64 `db:"amount"` + Day int `db:"day"` + } + err := r.db.Select(&plans, ` + SELECT effective, amount, day + FROM savings_recurring_plans + WHERE category_id = $1 AND user_id = $2 + ORDER BY effective ASC + `, cat.ID, userID) + if err != nil || len(plans) == 0 { + continue + } + + // Determine start month (same logic as old Savings) + start := time.Date(cat.CreatedAt.Year(), cat.CreatedAt.Month(), 1, 0, 0, 0, 0, time.UTC) + + // For multi categories, use earliest plan effective if earlier + if cat.IsMulti && len(plans) > 0 { + earliestEffective := time.Date(plans[0].Effective.Year(), plans[0].Effective.Month(), 1, 0, 0, 0, 0, time.UTC) + if earliestEffective.Before(start) { + start = earliestEffective + } + } + + // Build months list + var months []time.Time + for m := start; !m.After(currentMonth); m = m.AddDate(0, 1, 0) { + months = append(months, m) + } + if len(months) == 0 { + continue + } + + // Calculate required amounts per month + required := make([]float64, len(months)) + days := make([]int, len(months)) + for i, ms := range months { + for _, p := range plans { + planMonth := time.Date(p.Effective.Year(), p.Effective.Month(), 1, 0, 0, 0, 0, time.UTC) + if !planMonth.After(ms) { + required[i] = p.Amount + days[i] = p.Day + if days[i] < 1 || days[i] > 28 { + days[i] = 1 + } + } + } + } + + // Remaining amounts + remaining := make([]float64, len(required)) + copy(remaining, required) + + // Get deposits for this user and category since start + var deposits []struct { + Date time.Time `db:"date"` + Amount float64 `db:"amount"` + } + r.db.Select(&deposits, ` + SELECT date, amount FROM savings_transactions + WHERE category_id = $1 AND user_id = $2 AND type = 'deposit' AND date >= $3 + ORDER BY date ASC + `, cat.ID, userID, start) + + // Helper: find month index + findIdx := func(d time.Time) int { + ds := time.Date(d.Year(), d.Month(), 1, 0, 0, 0, 0, time.UTC) + for i, ms := range months { + if ms.Equal(ds) { + return i + } + } + if ds.Before(months[0]) { + return -1 + } + return len(months) - 1 + } + + // Allocate deposits + for _, dep := range deposits { + amt := dep.Amount + idx := findIdx(dep.Date) + if idx == -1 { + continue + } + + if remaining[idx] > 0 { + use := amt + if use > remaining[idx] { + use = remaining[idx] + } + remaining[idx] -= use + amt -= use + } + + for k := idx - 1; k >= 0 && amt > 0; k-- { + if remaining[k] <= 0 { + continue + } + use := amt + if use > remaining[k] { + use = remaining[k] + } + remaining[k] -= use + amt -= use + } + } + + // Get username + var userName string + r.db.Get(&userName, "SELECT username FROM users WHERE id = $1", userID) + + // All previous months (before current) with remaining balance are overdues + last := len(months) - 1 + for i := 0; i < last; i++ { + if remaining[i] > 0.01 { + if days[i] == 0 { + days[i] = 1 + } + dueDate := time.Date(months[i].Year(), months[i].Month(), days[i], 0, 0, 0, 0, time.UTC) + daysOverdue := int(now.Sub(dueDate).Hours() / 24) + + overdues = append(overdues, model.OverduePayment{ + CategoryID: cat.ID, + CategoryName: cat.Name, + UserID: userID, + UserName: userName, + Amount: remaining[i], + DueDay: days[i], + DaysOverdue: daysOverdue, + Month: months[i].Format("2006-01"), + }) + } + } + } + + return overdues, nil +} diff --git a/internal/repository/task.go b/internal/repository/task.go index ae33581..c56e2a8 100644 --- a/internal/repository/task.go +++ b/internal/repository/task.go @@ -21,8 +21,8 @@ func NewTaskRepository(db *sqlx.DB) *TaskRepository { func (r *TaskRepository) Create(task *model.Task) error { query := ` - INSERT INTO tasks (user_id, title, description, icon, color, due_date, priority, reminder_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + INSERT INTO tasks (user_id, title, description, icon, color, due_date, priority, reminder_time, is_recurring, recurrence_type, recurrence_interval, recurrence_end_date, parent_task_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) RETURNING id, created_at, updated_at` return r.db.QueryRow(query, @@ -34,6 +34,11 @@ func (r *TaskRepository) Create(task *model.Task) error { task.DueDate, task.Priority, task.ReminderTime, + task.IsRecurring, + task.RecurrenceType, + task.RecurrenceInterval, + task.RecurrenceEndDate, + task.ParentTaskID, ).Scan(&task.ID, &task.CreatedAt, &task.UpdatedAt) } @@ -41,13 +46,17 @@ func (r *TaskRepository) GetByID(id, userID int64) (*model.Task, error) { var task model.Task query := ` - SELECT id, user_id, title, description, icon, color, due_date, priority, reminder_time, completed_at, created_at, updated_at + SELECT id, user_id, title, description, icon, color, due_date, priority, reminder_time, completed_at, + is_recurring, recurrence_type, recurrence_interval, recurrence_end_date, parent_task_id, + created_at, updated_at FROM tasks WHERE id = $1 AND user_id = $2` err := r.db.QueryRow(query, id, userID).Scan( &task.ID, &task.UserID, &task.Title, &task.Description, &task.Icon, &task.Color, &task.DueDate, &task.Priority, - &task.ReminderTime, &task.CompletedAt, &task.CreatedAt, &task.UpdatedAt, + &task.ReminderTime, &task.CompletedAt, + &task.IsRecurring, &task.RecurrenceType, &task.RecurrenceInterval, &task.RecurrenceEndDate, &task.ParentTaskID, + &task.CreatedAt, &task.UpdatedAt, ) if err != nil { @@ -63,7 +72,9 @@ func (r *TaskRepository) GetByID(id, userID int64) (*model.Task, error) { func (r *TaskRepository) ListByUser(userID int64, completed *bool) ([]model.Task, error) { query := ` - SELECT id, user_id, title, description, icon, color, due_date, priority, reminder_time, completed_at, created_at, updated_at + SELECT id, user_id, title, description, icon, color, due_date, priority, reminder_time, completed_at, + is_recurring, recurrence_type, recurrence_interval, recurrence_end_date, parent_task_id, + created_at, updated_at FROM tasks WHERE user_id = $1` if completed != nil { @@ -88,7 +99,9 @@ func (r *TaskRepository) ListByUser(userID int64, completed *bool) ([]model.Task if err := rows.Scan( &task.ID, &task.UserID, &task.Title, &task.Description, &task.Icon, &task.Color, &task.DueDate, &task.Priority, - &task.ReminderTime, &task.CompletedAt, &task.CreatedAt, &task.UpdatedAt, + &task.ReminderTime, &task.CompletedAt, + &task.IsRecurring, &task.RecurrenceType, &task.RecurrenceInterval, &task.RecurrenceEndDate, &task.ParentTaskID, + &task.CreatedAt, &task.UpdatedAt, ); err != nil { return nil, err } @@ -104,7 +117,9 @@ func (r *TaskRepository) GetTodayTasks(userID int64) ([]model.Task, error) { today := time.Now().Format("2006-01-02") query := ` - SELECT id, user_id, title, description, icon, color, due_date, priority, reminder_time, completed_at, created_at, updated_at + SELECT id, user_id, title, description, icon, color, due_date, priority, reminder_time, completed_at, + is_recurring, recurrence_type, recurrence_interval, recurrence_end_date, parent_task_id, + created_at, updated_at FROM tasks WHERE user_id = $1 AND completed_at IS NULL AND due_date <= $2 ORDER BY priority DESC, due_date, created_at` @@ -122,7 +137,9 @@ func (r *TaskRepository) GetTodayTasks(userID int64) ([]model.Task, error) { if err := rows.Scan( &task.ID, &task.UserID, &task.Title, &task.Description, &task.Icon, &task.Color, &task.DueDate, &task.Priority, - &task.ReminderTime, &task.CompletedAt, &task.CreatedAt, &task.UpdatedAt, + &task.ReminderTime, &task.CompletedAt, + &task.IsRecurring, &task.RecurrenceType, &task.RecurrenceInterval, &task.RecurrenceEndDate, &task.ParentTaskID, + &task.CreatedAt, &task.UpdatedAt, ); err != nil { return nil, err } @@ -136,12 +153,14 @@ func (r *TaskRepository) GetTodayTasks(userID int64) ([]model.Task, error) { func (r *TaskRepository) GetTasksWithReminder(reminderTime string, date string) ([]model.Task, error) { query := ` - SELECT t.id, t.user_id, t.title, t.description, t.icon, t.color, t.due_date, t.priority, t.reminder_time, t.completed_at, t.created_at, t.updated_at + SELECT t.id, t.user_id, t.title, t.description, t.icon, t.color, t.due_date, t.priority, t.reminder_time, t.completed_at, + t.is_recurring, t.recurrence_type, t.recurrence_interval, t.recurrence_end_date, t.parent_task_id, + t.created_at, t.updated_at FROM tasks t JOIN users u ON t.user_id = u.id WHERE t.reminder_time = $1 AND t.completed_at IS NULL - AND (t.due_date IS NULL OR t.due_date >= $2) + AND t.due_date = $2 AND u.telegram_chat_id IS NOT NULL AND u.notifications_enabled = true` @@ -157,7 +176,9 @@ func (r *TaskRepository) GetTasksWithReminder(reminderTime string, date string) if err := rows.Scan( &task.ID, &task.UserID, &task.Title, &task.Description, &task.Icon, &task.Color, &task.DueDate, &task.Priority, - &task.ReminderTime, &task.CompletedAt, &task.CreatedAt, &task.UpdatedAt, + &task.ReminderTime, &task.CompletedAt, + &task.IsRecurring, &task.RecurrenceType, &task.RecurrenceInterval, &task.RecurrenceEndDate, &task.ParentTaskID, + &task.CreatedAt, &task.UpdatedAt, ); err != nil { return nil, err } @@ -171,8 +192,10 @@ func (r *TaskRepository) GetTasksWithReminder(reminderTime string, date string) func (r *TaskRepository) Update(task *model.Task) error { query := ` UPDATE tasks - SET title = $2, description = $3, icon = $4, color = $5, due_date = $6, priority = $7, reminder_time = $8, updated_at = CURRENT_TIMESTAMP - WHERE id = $1 AND user_id = $9 + SET title = $2, description = $3, icon = $4, color = $5, due_date = $6, priority = $7, reminder_time = $8, + is_recurring = $9, recurrence_type = $10, recurrence_interval = $11, recurrence_end_date = $12, + updated_at = CURRENT_TIMESTAMP + WHERE id = $1 AND user_id = $13 RETURNING updated_at` return r.db.QueryRow(query, @@ -184,6 +207,10 @@ func (r *TaskRepository) Update(task *model.Task) error { task.DueDate, task.Priority, task.ReminderTime, + task.IsRecurring, + task.RecurrenceType, + task.RecurrenceInterval, + task.RecurrenceEndDate, task.UserID, ).Scan(&task.UpdatedAt) } diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index bead38b..6924812 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -13,20 +13,22 @@ import ( ) type Scheduler struct { - cron *cron.Cron - bot *bot.Bot - userRepo *repository.UserRepository - taskRepo *repository.TaskRepository - habitRepo *repository.HabitRepository + cron *cron.Cron + bot *bot.Bot + userRepo *repository.UserRepository + taskRepo *repository.TaskRepository + habitRepo *repository.HabitRepository + freezeRepo *repository.HabitFreezeRepository } -func New(b *bot.Bot, userRepo *repository.UserRepository, taskRepo *repository.TaskRepository, habitRepo *repository.HabitRepository) *Scheduler { +func New(b *bot.Bot, userRepo *repository.UserRepository, taskRepo *repository.TaskRepository, habitRepo *repository.HabitRepository, freezeRepo *repository.HabitFreezeRepository) *Scheduler { return &Scheduler{ - cron: cron.New(), - bot: b, - userRepo: userRepo, - taskRepo: taskRepo, - habitRepo: habitRepo, + cron: cron.New(), + bot: b, + userRepo: userRepo, + taskRepo: taskRepo, + habitRepo: habitRepo, + freezeRepo: freezeRepo, } } @@ -117,19 +119,13 @@ func (s *Scheduler) sendMorningBriefing(userID, chatID int64, loc *time.Location return } - // Filter habits for today + // Filter habits for today (excluding frozen ones) weekday := int(time.Now().In(loc).Weekday()) + today := time.Now().In(loc).Truncate(24 * time.Hour) var todayHabits int for _, habit := range habits { - if habit.Frequency == "daily" { + if s.shouldShowHabitToday(habit, userID, weekday, today) { todayHabits++ - } else { - for _, day := range habit.TargetDays { - if day == weekday { - todayHabits++ - break - } - } } } @@ -185,25 +181,14 @@ func (s *Scheduler) sendEveningSummary(userID, chatID int64, loc *time.Location) return } - // Filter and count today's habits + // Filter and count today's habits (excluding frozen ones) weekday := int(time.Now().In(loc).Weekday()) + today := time.Now().In(loc).Truncate(24 * time.Hour) var completedHabits, incompleteHabits int var incompleteHabitNames []string for _, habit := range habits { - isToday := false - if habit.Frequency == "daily" { - isToday = true - } else { - for _, day := range habit.TargetDays { - if day == weekday { - isToday = true - break - } - } - } - - if isToday { + if s.shouldShowHabitToday(habit, userID, weekday, today) { completed, _ := s.habitRepo.IsHabitCompletedToday(habit.ID, userID) if completed { completedHabits++ @@ -313,17 +298,37 @@ func (s *Scheduler) checkHabitReminders(userID, chatID int64, currentTime string return } + today := time.Now().Truncate(24 * time.Hour) + for _, habit := range habits { if habit.UserID != userID { continue } + // Check if habit is frozen today + frozen, err := s.freezeRepo.IsHabitFrozenOnDate(habit.ID, today) + if err != nil { + log.Printf("Scheduler: error checking freeze for habit %d: %v", habit.ID, err) + continue + } + if frozen { + continue + } + // Check if already completed today completed, _ := s.habitRepo.IsHabitCompletedToday(habit.ID, userID) if completed { continue } + // For interval habits, check if it should be shown today + if (habit.Frequency == "interval" || habit.Frequency == "custom") && habit.TargetCount > 0 { + shouldShow, err := s.habitRepo.ShouldShowIntervalHabitToday(habit.ID, userID, habit.TargetCount, habit.StartDate) + if err != nil || !shouldShow { + continue + } + } + text := fmt.Sprintf("⏰ Напоминание о привычке:\n\n%s %s", habit.Icon, habit.Name) if habit.Description != "" { text += fmt.Sprintf("\n%s", habit.Description) @@ -339,3 +344,40 @@ func (s *Scheduler) checkHabitReminders(userID, chatID int64, currentTime string s.bot.SendMessageWithKeyboard(chatID, text, &keyboard) } } + +// shouldShowHabitToday checks if a habit should be shown today based on its frequency and freeze status +func (s *Scheduler) shouldShowHabitToday(habit model.Habit, userID int64, weekday int, today time.Time) bool { + // Check if habit is frozen today + frozen, err := s.freezeRepo.IsHabitFrozenOnDate(habit.ID, today) + if err != nil { + log.Printf("Scheduler: error checking freeze for habit %d: %v", habit.ID, err) + return false + } + if frozen { + return false + } + + if habit.Frequency == "daily" { + return true + } + + if habit.Frequency == "weekly" { + for _, day := range habit.TargetDays { + if day == weekday { + return true + } + } + return false + } + + // For interval habits + if (habit.Frequency == "interval" || habit.Frequency == "custom") && habit.TargetCount > 0 { + shouldShow, err := s.habitRepo.ShouldShowIntervalHabitToday(habit.ID, userID, habit.TargetCount, habit.StartDate) + if err != nil { + return false + } + return shouldShow + } + + return true +} diff --git a/internal/service/habit.go b/internal/service/habit.go index 488ae60..1654f57 100644 --- a/internal/service/habit.go +++ b/internal/service/habit.go @@ -2,18 +2,26 @@ package service import ( "database/sql" + "errors" "time" "github.com/daniil/homelab-api/internal/model" "github.com/daniil/homelab-api/internal/repository" ) +var ErrFutureDate = errors.New("cannot log habit for future date") +var ErrAlreadyLogged = errors.New("habit already logged for this date") + type HabitService struct { - habitRepo *repository.HabitRepository + habitRepo *repository.HabitRepository + freezeRepo *repository.HabitFreezeRepository } -func NewHabitService(habitRepo *repository.HabitRepository) *HabitService { - return &HabitService{habitRepo: habitRepo} +func NewHabitService(habitRepo *repository.HabitRepository, freezeRepo *repository.HabitFreezeRepository) *HabitService { + return &HabitService{ + habitRepo: habitRepo, + freezeRepo: freezeRepo, + } } func (s *HabitService) Create(userID int64, req *model.CreateHabitRequest) (*model.Habit, error) { @@ -32,6 +40,17 @@ func (s *HabitService) Create(userID int64, req *model.CreateHabitRequest) (*mod habit.ReminderTime = sql.NullString{String: *req.ReminderTime, Valid: true} } + // Handle start_date - default to today if not provided + if req.StartDate != nil && *req.StartDate != "" { + parsed, err := time.Parse("2006-01-02", *req.StartDate) + if err == nil { + habit.StartDate = sql.NullTime{Time: parsed, Valid: true} + } + } else { + // Default to today + habit.StartDate = sql.NullTime{Time: time.Now().Truncate(24 * time.Hour), Valid: true} + } + if err := s.habitRepo.Create(habit); err != nil { return nil, err } @@ -89,6 +108,16 @@ func (s *HabitService) Update(id, userID int64, req *model.UpdateHabitRequest) ( habit.ReminderTime = sql.NullString{String: *req.ReminderTime, Valid: true} } } + if req.StartDate != nil { + if *req.StartDate == "" { + habit.StartDate = sql.NullTime{Valid: false} + } else { + parsed, err := time.Parse("2006-01-02", *req.StartDate) + if err == nil { + habit.StartDate = sql.NullTime{Time: parsed, Valid: true} + } + } + } if req.IsArchived != nil { habit.IsArchived = *req.IsArchived } @@ -111,13 +140,29 @@ func (s *HabitService) Log(habitID, userID int64, req *model.LogHabitRequest) (* return nil, err } - date := time.Now().Truncate(24 * time.Hour) + today := time.Now().Truncate(24 * time.Hour) + date := today + if req.Date != "" { parsed, err := time.Parse("2006-01-02", req.Date) if err != nil { return nil, err } - date = parsed + date = parsed.Truncate(24 * time.Hour) + } + + // Validate: cannot log for future date + if date.After(today) { + return nil, ErrFutureDate + } + + // Check if already logged for this date + alreadyLogged, err := s.habitRepo.IsHabitCompletedOnDate(habitID, userID, date) + if err != nil { + return nil, err + } + if alreadyLogged { + return nil, ErrAlreadyLogged } log := &model.HabitLog{ @@ -160,11 +205,20 @@ func (s *HabitService) DeleteLog(logID, userID int64) error { func (s *HabitService) GetHabitStats(habitID, userID int64) (*model.HabitStats, error) { // Verify habit exists and belongs to user - if _, err := s.habitRepo.GetByID(habitID, userID); err != nil { + habit, err := s.habitRepo.GetByID(habitID, userID) + if err != nil { return nil, err } - return s.habitRepo.GetStats(habitID, userID) + stats, err := s.habitRepo.GetStats(habitID, userID) + if err != nil { + return nil, err + } + + // Recalculate completion percentage with frozen days excluded + stats.CompletionPct = s.calculateCompletionPctWithFreezes(habit, stats.TotalLogs) + + return stats, nil } func (s *HabitService) GetOverallStats(userID int64) (*model.OverallStats, error) { @@ -190,6 +244,75 @@ func (s *HabitService) GetOverallStats(userID int64) (*model.OverallStats, error }, nil } +// calculateCompletionPctWithFreezes calculates completion % excluding frozen days +func (s *HabitService) calculateCompletionPctWithFreezes(habit *model.Habit, totalLogs int) float64 { + if totalLogs == 0 { + return 0 + } + + // Use start_date if set, otherwise use created_at + var startDate time.Time + if habit.StartDate.Valid { + startDate = habit.StartDate.Time.Truncate(24 * time.Hour) + } else { + startDate = habit.CreatedAt.Truncate(24 * time.Hour) + } + today := time.Now().Truncate(24 * time.Hour) + + // Get frozen days count for this habit + frozenDays, err := s.freezeRepo.CountFrozenDaysInRange(habit.ID, startDate, today) + if err != nil { + frozenDays = 0 + } + + expectedCount := 0 + + // For interval habits, calculate expected differently + if (habit.Frequency == "interval" || habit.Frequency == "custom") && habit.TargetCount > 0 { + // Expected = (days since start - frozen days) / interval + 1 + totalDays := int(today.Sub(startDate).Hours()/24) + 1 - frozenDays + if totalDays <= 0 { + return 100 + } + expectedCount = (totalDays / habit.TargetCount) + 1 + } else { + for d := startDate; !d.After(today); d = d.AddDate(0, 0, 1) { + // Check if this day is frozen + frozen, _ := s.freezeRepo.IsHabitFrozenOnDate(habit.ID, d) + if frozen { + continue + } + + if habit.Frequency == "daily" { + expectedCount++ + } else if habit.Frequency == "weekly" && len(habit.TargetDays) > 0 { + weekday := int(d.Weekday()) + if weekday == 0 { + weekday = 7 + } + for _, td := range habit.TargetDays { + if td == weekday { + expectedCount++ + break + } + } + } else { + expectedCount++ + } + } + } + + if expectedCount == 0 { + return 100 + } + + pct := float64(totalLogs) / float64(expectedCount) * 100 + if pct > 100 { + pct = 100 + } + return pct +} + func defaultString(val, def string) string { if val == "" { return def diff --git a/internal/service/task.go b/internal/service/task.go index 1831ab9..89aa21f 100644 --- a/internal/service/task.go +++ b/internal/service/task.go @@ -18,12 +18,14 @@ func NewTaskService(taskRepo *repository.TaskRepository) *TaskService { func (s *TaskService) Create(userID int64, req *model.CreateTaskRequest) (*model.Task, error) { task := &model.Task{ - UserID: userID, - Title: req.Title, - Description: req.Description, - Icon: defaultString(req.Icon, "📋"), - Color: defaultString(req.Color, "#6B7280"), - Priority: req.Priority, + UserID: userID, + Title: req.Title, + Description: req.Description, + Icon: defaultString(req.Icon, "📋"), + Color: defaultString(req.Color, "#6B7280"), + Priority: req.Priority, + IsRecurring: req.IsRecurring, + RecurrenceInterval: defaultInt(req.RecurrenceInterval, 1), } if req.DueDate != nil && *req.DueDate != "" { @@ -37,6 +39,17 @@ func (s *TaskService) Create(userID int64, req *model.CreateTaskRequest) (*model task.ReminderTime = sql.NullString{String: *req.ReminderTime, Valid: true} } + if req.RecurrenceType != nil && *req.RecurrenceType != "" { + task.RecurrenceType = sql.NullString{String: *req.RecurrenceType, Valid: true} + } + + if req.RecurrenceEndDate != nil && *req.RecurrenceEndDate != "" { + parsed, err := time.Parse("2006-01-02", *req.RecurrenceEndDate) + if err == nil { + task.RecurrenceEndDate = sql.NullTime{Time: parsed, Valid: true} + } + } + if err := s.taskRepo.Create(task); err != nil { return nil, err } @@ -110,6 +123,31 @@ func (s *TaskService) Update(id, userID int64, req *model.UpdateTaskRequest) (*m } } + // Handle recurring fields + if req.IsRecurring != nil { + task.IsRecurring = *req.IsRecurring + } + if req.RecurrenceType != nil { + if *req.RecurrenceType == "" { + task.RecurrenceType = sql.NullString{Valid: false} + } else { + task.RecurrenceType = sql.NullString{String: *req.RecurrenceType, Valid: true} + } + } + if req.RecurrenceInterval != nil { + task.RecurrenceInterval = *req.RecurrenceInterval + } + if req.RecurrenceEndDate != nil { + if *req.RecurrenceEndDate == "" { + task.RecurrenceEndDate = sql.NullTime{Valid: false} + } else { + parsed, err := time.Parse("2006-01-02", *req.RecurrenceEndDate) + if err == nil { + task.RecurrenceEndDate = sql.NullTime{Time: parsed, Valid: true} + } + } + } + if err := s.taskRepo.Update(task); err != nil { return nil, err } @@ -123,15 +161,78 @@ func (s *TaskService) Delete(id, userID int64) error { } func (s *TaskService) Complete(id, userID int64) (*model.Task, error) { + // First, get the task to check if it's recurring + task, err := s.taskRepo.GetByID(id, userID) + if err != nil { + return nil, err + } + + // Complete the current task if err := s.taskRepo.Complete(id, userID); err != nil { return nil, err } + + // If task is recurring, create the next occurrence + if task.IsRecurring && task.RecurrenceType.Valid && task.DueDate.Valid { + s.createNextRecurrence(task) + } + return s.taskRepo.GetByID(id, userID) } +func (s *TaskService) createNextRecurrence(task *model.Task) { + // Calculate next due date based on recurrence type + var nextDueDate time.Time + interval := task.RecurrenceInterval + if interval < 1 { + interval = 1 + } + + currentDue := task.DueDate.Time + + switch task.RecurrenceType.String { + case "daily": + nextDueDate = currentDue.AddDate(0, 0, interval) + case "weekly": + nextDueDate = currentDue.AddDate(0, 0, 7*interval) + case "monthly": + nextDueDate = currentDue.AddDate(0, interval, 0) + case "custom": + nextDueDate = currentDue.AddDate(0, 0, interval) + default: + return // Unknown recurrence type, don't create + } + + // Check if next date is past the end date + if task.RecurrenceEndDate.Valid && nextDueDate.After(task.RecurrenceEndDate.Time) { + return // Don't create task past end date + } + + // Create the next task + nextTask := &model.Task{ + UserID: task.UserID, + Title: task.Title, + Description: task.Description, + Icon: task.Icon, + Color: task.Color, + Priority: task.Priority, + DueDate: sql.NullTime{Time: nextDueDate, Valid: true}, + ReminderTime: task.ReminderTime, + IsRecurring: true, + RecurrenceType: task.RecurrenceType, + RecurrenceInterval: task.RecurrenceInterval, + RecurrenceEndDate: task.RecurrenceEndDate, + ParentTaskID: sql.NullInt64{Int64: task.ID, Valid: true}, + } + + // Silently create, ignore errors + s.taskRepo.Create(nextTask) +} + func (s *TaskService) Uncomplete(id, userID int64) (*model.Task, error) { if err := s.taskRepo.Uncomplete(id, userID); err != nil { return nil, err } return s.taskRepo.GetByID(id, userID) } +