166 lines
4.5 KiB
Go
166 lines
4.5 KiB
Go
// Package repository provides the data access layer for the application.
|
|
package repository
|
|
|
|
import (
|
|
"errors"
|
|
|
|
"accounting-app/internal/models"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Sentinel errors returned by UserRepository; callers should compare them
// with errors.Is.
var (
	// ErrUserNotFound is returned when a lookup, update, or delete matches no
	// row. The OAuth-account methods of UserRepository reuse it as well.
	ErrUserNotFound = errors.New("user not found")

	// ErrUserEmailExists is returned by Create when the email is already taken.
	ErrUserEmailExists = errors.New("email already exists")

	// ErrOAuthAccountExists is returned by CreateOAuthAccount when the
	// (provider, provider ID) pair is already linked.
	ErrOAuthAccountExists = errors.New("oauth account already linked")
)
|
|
|
|
// UserRepository handles database operations for users
|
|
// Feature: api-interface-optimization
|
|
// Validates: Requirements 12, 13
|
|
type UserRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewUserRepository creates a new UserRepository instance
|
|
func NewUserRepository(db *gorm.DB) *UserRepository {
|
|
return &UserRepository{db: db}
|
|
}
|
|
|
|
// Create creates a new user in the database
|
|
func (r *UserRepository) Create(user *models.User) error {
|
|
// Check if email already exists
|
|
var count int64
|
|
if err := r.db.Model(&models.User{}).Where("email = ?", user.Email).Count(&count).Error; err != nil {
|
|
return err
|
|
}
|
|
if count > 0 {
|
|
return ErrUserEmailExists
|
|
}
|
|
|
|
return r.db.Create(user).Error
|
|
}
|
|
|
|
// GetByID retrieves a user by ID
|
|
func (r *UserRepository) GetByID(id uint) (*models.User, error) {
|
|
var user models.User
|
|
if err := r.db.First(&user, id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// GetByEmail retrieves a user by email
|
|
func (r *UserRepository) GetByEmail(email string) (*models.User, error) {
|
|
var user models.User
|
|
if err := r.db.Where("email = ?", email).First(&user).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
|
|
// Update updates a user in the database
|
|
func (r *UserRepository) Update(user *models.User) error {
|
|
return r.db.Save(user).Error
|
|
}
|
|
|
|
// Delete soft deletes a user
|
|
func (r *UserRepository) Delete(id uint) error {
|
|
result := r.db.Delete(&models.User{}, id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrUserNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetByOAuthProvider retrieves a user by OAuth provider and provider ID
|
|
func (r *UserRepository) GetByOAuthProvider(provider, providerID string) (*models.User, error) {
|
|
var oauth models.OAuthAccount
|
|
if err := r.db.Where("provider = ? AND provider_id = ?", provider, providerID).First(&oauth).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
var user models.User
|
|
if err := r.db.First(&user, oauth.UserID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// CreateOAuthAccount creates a new OAuth account linked to a user
|
|
func (r *UserRepository) CreateOAuthAccount(oauth *models.OAuthAccount) error {
|
|
// Check if OAuth account already exists
|
|
var count int64
|
|
if err := r.db.Model(&models.OAuthAccount{}).
|
|
Where("provider = ? AND provider_id = ?", oauth.Provider, oauth.ProviderID).
|
|
Count(&count).Error; err != nil {
|
|
return err
|
|
}
|
|
if count > 0 {
|
|
return ErrOAuthAccountExists
|
|
}
|
|
|
|
return r.db.Create(oauth).Error
|
|
}
|
|
|
|
// GetOAuthAccounts retrieves all OAuth accounts for a user
|
|
func (r *UserRepository) GetOAuthAccounts(userID uint) ([]models.OAuthAccount, error) {
|
|
var accounts []models.OAuthAccount
|
|
if err := r.db.Where("user_id = ?", userID).Find(&accounts).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return accounts, nil
|
|
}
|
|
|
|
// UpdateOAuthToken updates the access token for an OAuth account
|
|
func (r *UserRepository) UpdateOAuthToken(provider, providerID, accessToken string) error {
|
|
result := r.db.Model(&models.OAuthAccount{}).
|
|
Where("provider = ? AND provider_id = ?", provider, providerID).
|
|
Update("access_token", accessToken)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrUserNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteOAuthAccount removes an OAuth account link
|
|
func (r *UserRepository) DeleteOAuthAccount(userID uint, provider string) error {
|
|
result := r.db.Where("user_id = ? AND provider = ?", userID, provider).Delete(&models.OAuthAccount{})
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return ErrUserNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// EmailExists checks if an email is already registered
|
|
func (r *UserRepository) EmailExists(email string) (bool, error) {
|
|
var count int64
|
|
if err := r.db.Model(&models.User{}).Where("email = ?", email).Count(&count).Error; err != nil {
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
}
|