package repository

import (
	"errors"
	"fmt"

	"accounting-app/internal/models"

	"gorm.io/gorm"
)

// Category repository errors.
var (
	// ErrCategoryNotFound is returned when no category matches the given
	// user and ID or name.
	ErrCategoryNotFound = errors.New("category not found")
	// ErrCategoryInUse is returned by Delete when a transaction, budget or
	// recurring transaction still references the category.
	ErrCategoryInUse = errors.New("category is in use and cannot be deleted")
	// ErrCategoryHasChildren is returned by Delete when another category
	// uses the category as its parent.
	ErrCategoryHasChildren = errors.New("category has children and cannot be deleted")
)

// categoryListOrder is the ORDER BY clause shared by every query that lists
// categories: user-defined sort_order first, creation time as tie-breaker.
const categoryListOrder = "sort_order ASC, created_at ASC"

// CategoryRepository handles database operations for categories.
// Lookups and listings are scoped to a single user via the user_id column.
type CategoryRepository struct {
	db *gorm.DB
}

// NewCategoryRepository creates a CategoryRepository backed by db.
func NewCategoryRepository(db *gorm.DB) *CategoryRepository {
	return &CategoryRepository{db: db}
}

// Create inserts category into the database. GORM writes generated values
// such as the primary key back into category.
func (r *CategoryRepository) Create(category *models.Category) error {
	if err := r.db.Create(category).Error; err != nil {
		return fmt.Errorf("failed to create category: %w", err)
	}
	return nil
}

// GetByID returns the category with the given id owned by userID.
// It returns ErrCategoryNotFound if no such category exists.
func (r *CategoryRepository) GetByID(userID uint, id uint) (*models.Category, error) {
	var category models.Category
	if err := r.db.Where("user_id = ?", userID).First(&category, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to get category: %w", err)
	}
	return &category, nil
}

// GetAll returns every category owned by userID, ordered by sort_order and
// then created_at.
func (r *CategoryRepository) GetAll(userID uint) ([]models.Category, error) {
	var categories []models.Category
	if err := r.db.Where("user_id = ?", userID).Order(categoryListOrder).Find(&categories).Error; err != nil {
		return nil, fmt.Errorf("failed to get categories: %w", err)
	}
	return categories, nil
}

// GetByType returns all categories of the given type (for example income or
// expense) owned by userID, ordered by sort_order and then created_at.
func (r *CategoryRepository) GetByType(userID uint, categoryType models.CategoryType) ([]models.Category, error) {
	var categories []models.Category
	if err := r.db.Where("user_id = ? AND type = ?", userID, categoryType).Order(categoryListOrder).Find(&categories).Error; err != nil {
		return nil, fmt.Errorf("failed to get categories by type: %w", err)
	}
	return categories, nil
}

// GetRootCategories returns the top-level categories (parent_id IS NULL)
// owned by userID, ordered by sort_order and then created_at.
func (r *CategoryRepository) GetRootCategories(userID uint) ([]models.Category, error) {
	var categories []models.Category
	if err := r.db.Where("user_id = ? AND parent_id IS NULL", userID).Order(categoryListOrder).Find(&categories).Error; err != nil {
		return nil, fmt.Errorf("failed to get root categories: %w", err)
	}
	return categories, nil
}

// GetChildren returns the direct children of parentID owned by userID,
// ordered by sort_order and then created_at. An unknown parentID yields an
// empty result rather than an error.
func (r *CategoryRepository) GetChildren(userID uint, parentID uint) ([]models.Category, error) {
	var categories []models.Category
	if err := r.db.Where("user_id = ? AND parent_id = ?", userID, parentID).Order(categoryListOrder).Find(&categories).Error; err != nil {
		return nil, fmt.Errorf("failed to get child categories: %w", err)
	}
	return categories, nil
}

// GetWithChildren returns the category with the given id owned by userID,
// with its direct Children preloaded and ordered by sort_order and then
// created_at. It returns ErrCategoryNotFound if no such category exists.
func (r *CategoryRepository) GetWithChildren(userID uint, id uint) (*models.Category, error) {
	var category models.Category
	err := r.db.
		Preload("Children", func(db *gorm.DB) *gorm.DB { return db.Order(categoryListOrder) }).
		Where("user_id = ?", userID).
		First(&category, id).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to get category with children: %w", err)
	}
	return &category, nil
}

// GetWithParent returns the category with the given id owned by userID,
// with its Parent association preloaded. It returns ErrCategoryNotFound if
// no such category exists.
func (r *CategoryRepository) GetWithParent(userID uint, id uint) (*models.Category, error) {
	var category models.Category
	if err := r.db.Preload("Parent").Where("user_id = ?", userID).First(&category, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to get category with parent: %w", err)
	}
	return &category, nil
}

// Update saves all fields of category. The category must already exist and
// belong to category.UserID; otherwise ErrCategoryNotFound is returned.
//
// The existence check and the save run in one database transaction, so they
// commit or roll back together. Without this, a row deleted between the two
// statements could be silently re-created, because recent GORM versions make
// Save fall back to an upsert when the UPDATE matches no rows.
func (r *CategoryRepository) Update(category *models.Category) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var existing models.Category
		if err := tx.Where("user_id = ?", category.UserID).First(&existing, category.ID).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrCategoryNotFound
			}
			return fmt.Errorf("failed to check category existence: %w", err)
		}

		if err := tx.Save(category).Error; err != nil {
			return fmt.Errorf("failed to update category: %w", err)
		}
		return nil
	})
}

// Delete removes the category with the given id owned by userID.
//
// It returns one of these errors, checked in this order:
//   - ErrCategoryNotFound if the category does not exist for userID.
//   - ErrCategoryHasChildren if another category has it as parent.
//   - ErrCategoryInUse if any transaction, budget or recurring transaction
//     references it.
//
// The checks and the delete run inside a single database transaction, so
// they commit or roll back together. This narrows the window in which a
// concurrent insert could leave rows pointing at a deleted category; the
// exact guarantee depends on the database's isolation level.
func (r *CategoryRepository) Delete(userID uint, id uint) error {
	// Tables that must not reference the category, checked in this order.
	// desc completes the "failed to check ..." error message; inUse is the
	// sentinel returned when at least one referencing row exists.
	references := []struct {
		model  any
		column string
		desc   string
		inUse  error
	}{
		{&models.Category{}, "parent_id", "child categories", ErrCategoryHasChildren},
		{&models.Transaction{}, "category_id", "category transactions", ErrCategoryInUse},
		{&models.Budget{}, "category_id", "category budgets", ErrCategoryInUse},
		{&models.RecurringTransaction{}, "category_id", "category recurring transactions", ErrCategoryInUse},
	}

	return r.db.Transaction(func(tx *gorm.DB) error {
		var category models.Category
		if err := tx.Where("user_id = ?", userID).First(&category, id).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrCategoryNotFound
			}
			return fmt.Errorf("failed to check category existence: %w", err)
		}

		for _, ref := range references {
			var count int64
			// ref.column comes from the fixed table above, never from user input.
			if err := tx.Model(ref.model).Where(ref.column+" = ?", id).Count(&count).Error; err != nil {
				return fmt.Errorf("failed to check %s: %w", ref.desc, err)
			}
			if count > 0 {
				return ref.inUse
			}
		}

		// Hard delete: Category has no DeletedAt field, so GORM does not soft-delete it.
		if err := tx.Delete(&category).Error; err != nil {
			return fmt.Errorf("failed to delete category: %w", err)
		}
		return nil
	})
}

// ExistsByID reports whether a category with the given id is owned by userID.
func (r *CategoryRepository) ExistsByID(userID uint, id uint) (bool, error) {
	var count int64
	if err := r.db.Model(&models.Category{}).Where("user_id = ? AND id = ?", userID, id).Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check category existence: %w", err)
	}
	return count > 0, nil
}

// ExistsByName reports whether userID owns a category with exactly the given
// name, of any type.
func (r *CategoryRepository) ExistsByName(userID uint, name string) (bool, error) {
	var count int64
	if err := r.db.Model(&models.Category{}).Where("user_id = ? AND name = ?", userID, name).Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check category name existence: %w", err)
	}
	return count > 0, nil
}

// ExistsByNameAndType reports whether userID owns a category with the given
// name and type.
func (r *CategoryRepository) ExistsByNameAndType(userID uint, name string, categoryType models.CategoryType) (bool, error) {
	var count int64
	if err := r.db.Model(&models.Category{}).Where("user_id = ? AND name = ? AND type = ?", userID, name, categoryType).Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check category name and type existence: %w", err)
	}
	return count > 0, nil
}

// ExistsByNameExcludingID reports whether userID owns a category with the
// given name other than the one with excludeID. It is useful for uniqueness
// checks when renaming a category.
func (r *CategoryRepository) ExistsByNameExcludingID(userID uint, name string, excludeID uint) (bool, error) {
	var count int64
	if err := r.db.Model(&models.Category{}).Where("user_id = ? AND name = ? AND id != ?", userID, name, excludeID).Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check category name existence: %w", err)
	}
	return count > 0, nil
}

// GetRootCategoriesByType returns the top-level categories (parent_id IS
// NULL) of the given type owned by userID, ordered by sort_order and then
// created_at.
func (r *CategoryRepository) GetRootCategoriesByType(userID uint, categoryType models.CategoryType) ([]models.Category, error) {
	var categories []models.Category
	if err := r.db.Where("user_id = ? AND parent_id IS NULL AND type = ?", userID, categoryType).Order(categoryListOrder).Find(&categories).Error; err != nil {
		return nil, fmt.Errorf("failed to get root categories by type: %w", err)
	}
	return categories, nil
}

// GetAllWithChildren returns the user's top-level categories (parent_id IS
// NULL), each with its direct Children preloaded. Only one level of nesting
// is loaded. Both the roots and the children are ordered by sort_order and
// then created_at.
func (r *CategoryRepository) GetAllWithChildren(userID uint) ([]models.Category, error) {
	var categories []models.Category
	err := r.db.
		Preload("Children", func(db *gorm.DB) *gorm.DB { return db.Order(categoryListOrder) }).
		Where("user_id = ? AND parent_id IS NULL", userID).
		Order(categoryListOrder).
		Find(&categories).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get categories with children: %w", err)
	}
	return categories, nil
}

// CountByType returns how many categories of the given type userID owns.
func (r *CategoryRepository) CountByType(userID uint, categoryType models.CategoryType) (int64, error) {
	var count int64
	if err := r.db.Model(&models.Category{}).Where("user_id = ? AND type = ?", userID, categoryType).Count(&count).Error; err != nil {
		return 0, fmt.Errorf("failed to count categories by type: %w", err)
	}
	return count, nil
}

// GetByName returns the category owned by userID with exactly the given
// name, or ErrCategoryNotFound. If several categories share the name (for
// example one income and one expense category), GORM's First returns the one
// with the smallest primary key.
func (r *CategoryRepository) GetByName(userID uint, name string) (*models.Category, error) {
	var category models.Category
	if err := r.db.Where("user_id = ? AND name = ?", userID, name).First(&category).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrCategoryNotFound
		}
		return nil, fmt.Errorf("failed to get category by name: %w", err)
	}
	return &category, nil
}