114 lines
2.7 KiB
Go
114 lines
2.7 KiB
Go
package repository
|
|
|
|
import (
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"time"

	"github.com/daniil/homelab-api/internal/model"
	"github.com/jmoiron/sqlx"
)
|
|
|
|
// Sentinel errors returned by EmailTokenRepository. Compare against them with
// errors.Is (or ==).
var (
	// ErrTokenNotFound is returned when no stored token matches the lookup,
	// and by Validate when the token exists but has a different type.
	ErrTokenNotFound = errors.New("token not found")

	// ErrTokenExpired is returned by Validate when the token's expiry time
	// has already passed.
	ErrTokenExpired = errors.New("token expired")

	// ErrTokenUsed is returned by Validate when the token has already been
	// marked as used.
	ErrTokenUsed = errors.New("token already used")
)
|
|
|
|
type EmailTokenRepository struct {
|
|
db *sqlx.DB
|
|
}
|
|
|
|
func NewEmailTokenRepository(db *sqlx.DB) *EmailTokenRepository {
|
|
return &EmailTokenRepository{db: db}
|
|
}
|
|
|
|
func (r *EmailTokenRepository) Create(userID int64, tokenType string, expiry time.Duration) (*model.EmailToken, error) {
|
|
token, err := generateSecureToken(32)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
emailToken := &model.EmailToken{
|
|
UserID: userID,
|
|
Token: token,
|
|
Type: tokenType,
|
|
ExpiresAt: time.Now().Add(expiry),
|
|
}
|
|
|
|
query := `
|
|
INSERT INTO email_tokens (user_id, token, type, expires_at)
|
|
VALUES ($1, $2, $3, $4)
|
|
RETURNING id, created_at`
|
|
|
|
err = r.db.QueryRow(query, emailToken.UserID, emailToken.Token, emailToken.Type, emailToken.ExpiresAt).
|
|
Scan(&emailToken.ID, &emailToken.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return emailToken, nil
|
|
}
|
|
|
|
func (r *EmailTokenRepository) GetByToken(token string) (*model.EmailToken, error) {
|
|
var emailToken model.EmailToken
|
|
query := `SELECT id, user_id, token, type, expires_at, used_at, created_at FROM email_tokens WHERE token = $1`
|
|
|
|
if err := r.db.Get(&emailToken, query, token); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrTokenNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return &emailToken, nil
|
|
}
|
|
|
|
func (r *EmailTokenRepository) Validate(token, tokenType string) (*model.EmailToken, error) {
|
|
emailToken, err := r.GetByToken(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if emailToken.Type != tokenType {
|
|
return nil, ErrTokenNotFound
|
|
}
|
|
|
|
if emailToken.UsedAt != nil {
|
|
return nil, ErrTokenUsed
|
|
}
|
|
|
|
if time.Now().After(emailToken.ExpiresAt) {
|
|
return nil, ErrTokenExpired
|
|
}
|
|
|
|
return emailToken, nil
|
|
}
|
|
|
|
func (r *EmailTokenRepository) MarkUsed(id int64) error {
|
|
query := `UPDATE email_tokens SET used_at = CURRENT_TIMESTAMP WHERE id = $1`
|
|
_, err := r.db.Exec(query, id)
|
|
return err
|
|
}
|
|
|
|
func (r *EmailTokenRepository) DeleteExpired() error {
|
|
query := `DELETE FROM email_tokens WHERE expires_at < CURRENT_TIMESTAMP OR used_at IS NOT NULL`
|
|
_, err := r.db.Exec(query)
|
|
return err
|
|
}
|
|
|
|
func (r *EmailTokenRepository) DeleteByUserAndType(userID int64, tokenType string) error {
|
|
query := `DELETE FROM email_tokens WHERE user_id = $1 AND type = $2 AND used_at IS NULL`
|
|
_, err := r.db.Exec(query, userID, tokenType)
|
|
return err
|
|
}
|
|
|
|
// generateSecureToken reads length bytes from crypto/rand and returns them
// hex-encoded: a lowercase string of 2*length characters. If the random
// source fails, it returns "" and that error unchanged. length must be
// non-negative; a negative value makes the make call panic.
func generateSecureToken(length int) (string, error) {
	buf := make([]byte, length)

	_, err := rand.Read(buf)
	if err != nil {
		return "", err
	}

	return hex.EncodeToString(buf), nil
}
|