Files
2026-03-25 19:07:57 +01:00

106 lines
2.8 KiB
Go

package database
import (
"Portifolio/internal/model"
"database/sql"
"fmt"
)
// InsertCategory creates the category described by rc or, when a row with the
// same (company_id, name) already exists, updates that row's parent_id via an
// INSERT ... ON CONFLICT(company_id, name) DO UPDATE upsert.
//
// Before writing it:
//   - validates rc with rc.Validate(false);
//   - checks that rc.CompanyID refers to an existing company;
//   - when rc.ParentID is set to a non-zero value, checks that the parent
//     category exists for the same company.
//
// A nil or zero ParentID is stored as SQL NULL.
//
// NOTE(review): rc is received by value, so the id read back after the upsert
// lands in this function's local copy only and is never visible to the
// caller. If callers need the id, the signature has to change (e.g. take
// *model.RevenueCategory or also return the id). The lookup is kept so that a
// failure to read the row back is still reported as an error.
func InsertCategory(db *sql.DB, rc model.RevenueCategory) error {
	if err := rc.Validate(false); err != nil {
		return fmt.Errorf("failed to insert: %w", err)
	}

	company, err := GetCompanyByID(db, rc.CompanyID)
	if err != nil {
		return fmt.Errorf("failed to check company: %w", err)
	}
	if company == nil {
		return fmt.Errorf("company %d not found", rc.CompanyID)
	}

	// Resolve the parent once: nil or zero means "no parent" and is stored as
	// NULL; any other value must reference an existing category of the same
	// company.
	var parentID sql.NullInt64
	if rc.ParentID != nil && *rc.ParentID != 0 {
		parent, err := GetCategoryByID(db, rc.CompanyID, *rc.ParentID)
		if err != nil {
			return fmt.Errorf("failed to check parent category: %w", err)
		}
		if parent == nil {
			return fmt.Errorf("parent category %d not found", *rc.ParentID)
		}
		parentID = sql.NullInt64{Int64: int64(*rc.ParentID), Valid: true}
	}

	_, err = db.Exec(
		`INSERT INTO category (company_id, parent_id, name) VALUES (?, ?, ?)
ON CONFLICT(company_id, name) DO UPDATE SET parent_id=excluded.parent_id`,
		rc.CompanyID, parentID, rc.Name,
	)
	if err != nil {
		return fmt.Errorf("upsert category: %w", err)
	}

	// See NOTE(review) above: rc.ID is local to this call.
	err = db.QueryRow(
		`SELECT id FROM category WHERE company_id = ? AND name = ?`,
		rc.CompanyID, rc.Name,
	).Scan(&rc.ID)
	if err != nil {
		return fmt.Errorf("select category id: %w", err)
	}
	return nil
}
// GetCategoriesByCompanyID returns the names of all categories belonging to
// companyID, in whatever order the database yields them (the query has no
// ORDER BY). A company without categories yields a nil slice and nil error.
//
// On any failure (query, scan, or row iteration) it returns an empty, non-nil
// slice together with an error wrapping the underlying cause.
func GetCategoriesByCompanyID(db *sql.DB, companyID int) ([]string, error) {
	rows, err := db.Query("SELECT name FROM category WHERE company_id = ?", companyID)
	if err != nil {
		return []string{}, fmt.Errorf("query categories for company %d: %w", companyID, err)
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return []string{}, fmt.Errorf("scan category name: %w", err)
		}
		names = append(names, name)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails; only rows.Err tells the two apart. Without this check a
	// mid-stream failure would be reported as a shorter, successful list.
	if err := rows.Err(); err != nil {
		return []string{}, fmt.Errorf("iterate categories for company %d: %w", companyID, err)
	}
	return names, nil
}
// GetCategoryByID loads the category with the given id, scoped to companyID
// so that a category belonging to a different company is never returned.
//
// It returns a non-nil error when no matching row exists ("category <id> not
// found for company <companyID>") or when the query/scan fails. In both cases
// the returned pointer is nil; on success it points to the loaded category.
func GetCategoryByID(db *sql.DB, companyID int, id int) (*model.RevenueCategory, error) {
	var rc model.RevenueCategory
	err := db.QueryRow(
		`SELECT id, company_id, parent_id, name FROM category WHERE company_id = ? AND id = ?`,
		companyID, id,
	).Scan(&rc.ID, &rc.CompanyID, &rc.ParentID, &rc.Name)
	if err == sql.ErrNoRows {
		// %d, not %q: %q on an integer prints it as a quoted rune literal
		// (e.g. '\x05'), not as the number.
		return nil, fmt.Errorf("category %d not found for company %d", id, companyID)
	}
	if err != nil {
		// Do not hand back rc here: Scan may have filled only some fields.
		return nil, fmt.Errorf("get category by id: %w", err)
	}
	return &rc, nil
}
// GetCategoryByName fetches the category of companyID whose name matches name
// exactly.
//
// If no row matches, it returns a "not found" error; any other query or scan
// failure is wrapped with context. The returned value should not be relied on
// when the error is non-nil.
func GetCategoryByName(db *sql.DB, companyID int, name string) (model.RevenueCategory, error) {
	const query = `SELECT id, company_id, parent_id, name FROM category WHERE company_id = ? AND name = ?`

	var category model.RevenueCategory
	row := db.QueryRow(query, companyID, name)
	switch scanErr := row.Scan(&category.ID, &category.CompanyID, &category.ParentID, &category.Name); {
	case scanErr == sql.ErrNoRows:
		return category, fmt.Errorf("category %q not found for company %d", name, companyID)
	case scanErr != nil:
		return category, fmt.Errorf("get category by name: %w", scanErr)
	default:
		return category, nil
	}
}