package database

import (
	"database/sql"
	"fmt"

	"Portifolio/internal/model"
)

// InsertCategory validates rc and upserts it into the category table.
//
// The company referenced by rc.CompanyID must exist. When rc.ParentID is
// non-nil and non-zero, the parent category must exist within the same
// company; otherwise parent_id is stored as NULL. If a row with the same
// (company_id, name) already exists, only its parent_id is updated.
//
// NOTE(review): rc is received by value, so the id read back after the upsert
// is written to a local copy and is not visible to the caller. Returning the
// id would require a signature change, so the lookup is kept here only as a
// check that the row is readable after the write.
func InsertCategory(db *sql.DB, rc model.RevenueCategory) error {
	if err := rc.Validate(false); err != nil {
		return fmt.Errorf("failed to insert: %w", err)
	}

	company, err := GetCompanyByID(db, rc.CompanyID)
	if err != nil {
		return fmt.Errorf("failed to check company: %w", err)
	}
	if company == nil {
		return fmt.Errorf("company %d not found", rc.CompanyID)
	}

	// A zero-valued NullInt64 (Valid == false) is written as SQL NULL.
	var parentID sql.NullInt64
	if rc.ParentID != nil && *rc.ParentID != 0 {
		// GetCategoryByID reports a missing row through its error, so the
		// returned pointer does not need a separate nil check.
		if _, err := GetCategoryByID(db, rc.CompanyID, *rc.ParentID); err != nil {
			return fmt.Errorf("failed to check parent category: %w", err)
		}
		parentID = sql.NullInt64{Int64: int64(*rc.ParentID), Valid: true}
	}

	_, err = db.Exec(
		`INSERT INTO category (company_id, parent_id, name)
		 VALUES (?, ?, ?)
		 ON CONFLICT(company_id, name) DO UPDATE SET parent_id=excluded.parent_id`,
		rc.CompanyID, parentID, rc.Name,
	)
	if err != nil {
		return fmt.Errorf("upsert category: %w", err)
	}

	err = db.QueryRow(
		`SELECT id FROM category WHERE company_id = ? AND name = ?`,
		rc.CompanyID, rc.Name,
	).Scan(&rc.ID)
	if err != nil {
		return fmt.Errorf("select category id: %w", err)
	}
	return nil
}

// GetCategoriesByCompanyID returns the names of all categories stored for
// companyID. On failure it returns an empty, non-nil slice together with the
// error; when the company has no categories it returns a nil slice and a nil
// error.
func GetCategoriesByCompanyID(db *sql.DB, companyID int) ([]string, error) {
	rows, err := db.Query("SELECT name FROM category WHERE company_id = ?", companyID)
	if err != nil {
		return []string{}, err
	}
	defer rows.Close()

	var list []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return []string{}, err
		}
		list = append(list, name)
	}
	// rows.Next returns false both at end of data and on failure; without
	// this check an iteration error would yield a truncated list as success.
	if err := rows.Err(); err != nil {
		return []string{}, err
	}
	return list, nil
}

// GetCategoryByID loads the category with the given ID that belongs to
// companyID. The returned pointer is never nil; when err is non-nil
// (including the not-found case) the pointed-to value must not be relied on.
func GetCategoryByID(db *sql.DB, companyID int, ID int) (*model.RevenueCategory, error) {
	var rc model.RevenueCategory
	err := db.QueryRow(
		`SELECT id, company_id, parent_id, name
		 FROM category
		 WHERE company_id = ? AND id = ?`,
		companyID, ID,
	).Scan(&rc.ID, &rc.CompanyID, &rc.ParentID, &rc.Name)
	if err == sql.ErrNoRows {
		return &rc, fmt.Errorf("category %d not found for company %d", ID, companyID)
	}
	if err != nil {
		return &rc, fmt.Errorf("get category by id: %w", err)
	}
	return &rc, nil
}

// GetCategoryByName loads the category called name that belongs to companyID.
// When err is non-nil (including the not-found case) the returned value must
// not be relied on.
func GetCategoryByName(db *sql.DB, companyID int, name string) (model.RevenueCategory, error) {
	var rc model.RevenueCategory
	err := db.QueryRow(
		`SELECT id, company_id, parent_id, name
		 FROM category
		 WHERE company_id = ? AND name = ?`,
		companyID, name,
	).Scan(&rc.ID, &rc.CompanyID, &rc.ParentID, &rc.Name)
	if err == sql.ErrNoRows {
		return rc, fmt.Errorf("category %q not found for company %d", name, companyID)
	}
	if err != nil {
		return rc, fmt.Errorf("get category by name: %w", err)
	}
	return rc, nil
}