diff --git a/Portifolio b/Portifolio index c4e0c7d..8f24077 100755 Binary files a/Portifolio and b/Portifolio differ diff --git a/app.db b/app.db index a8a8782..b6bd6bd 100644 Binary files a/app.db and b/app.db differ diff --git a/internal/database/category.go b/internal/database/category.go index 323ecc4..6dad7bb 100644 --- a/internal/database/category.go +++ b/internal/database/category.go @@ -52,3 +52,54 @@ func InsertCategory(db *sql.DB, rc model.RevenueCategory) error { } return nil } + +func GetCategoriesByCompanyID(db *sql.DB, companyID int) ([]string, error) { + rows, err := db.Query("SELECT name FROM category WHERE company_id = ?", companyID) + if err != nil { + return []string{}, err + } + defer rows.Close() + + var list []string + for rows.Next() { + var Name string + err := rows.Scan(&Name) + if err != nil { + return []string{}, err + } + + list = append(list, Name) + } + return list, nil +} + +func GetCategoryByID(db *sql.DB, companyID int, ID int) (*model.RevenueCategory, error) { + var rc model.RevenueCategory + err := db.QueryRow( + `SELECT id, company_id, parent_id, name FROM category WHERE company_id = ? AND id = ?`, + companyID, ID, + ).Scan(&rc.ID, &rc.CompanyID, &rc.ParentID, &rc.Name) + if err == sql.ErrNoRows { + return &rc, fmt.Errorf("category %q not found for company %d", ID, companyID) + } + if err != nil { + return &rc, fmt.Errorf("get category by name: %w", err) + } + return &rc, nil +} + +func GetCategoryByName(db *sql.DB, companyID int, name string) (model.RevenueCategory, error) { + + var rc model.RevenueCategory + err := db.QueryRow( + `SELECT id, company_id, parent_id, name FROM category WHERE company_id = ? AND name = ?`, + companyID, name, + ).Scan(&rc.ID, &rc.CompanyID, &rc.ParentID, &rc.Name) + if err == sql.ErrNoRows { + return rc, fmt.Errorf("category %q not found for company %d", name, companyID) + } + if err != nil { + return rc, fmt.Errorf("get category by name: %w", err) + } + return rc, nil +} diff --git a/internal/database/revenue.go b/internal/database/revenue.go index f033ac0..87302a3 100644 --- a/internal/database/revenue.go +++ b/internal/database/revenue.go @@ -9,37 +9,6 @@ import ( _ "github.com/mattn/go-sqlite3" ) -func GetCategoryByID(db *sql.DB, companyID int, ID int) (*model.RevenueCategory, error) { - var rc model.RevenueCategory - err := db.QueryRow( - `SELECT id, company_id, parent_id, name FROM category WHERE company_id = ? AND id = ?`, - companyID, ID, - ).Scan(&rc.ID, &rc.CompanyID, &rc.ParentID, &rc.Name) - if err == sql.ErrNoRows { - return &rc, fmt.Errorf("category %q not found for company %d", ID, companyID) - } - if err != nil { - return &rc, fmt.Errorf("get category by name: %w", err) - } - return &rc, nil -} - -func GetCategoryByName(db *sql.DB, companyID int, name string) (model.RevenueCategory, error) { - - var rc model.RevenueCategory - err := db.QueryRow( - `SELECT id, company_id, parent_id, name FROM category WHERE company_id = ? AND name = ?`, - companyID, name, - ).Scan(&rc.ID, &rc.CompanyID, &rc.ParentID, &rc.Name) - if err == sql.ErrNoRows { - return rc, fmt.Errorf("category %q not found for company %d", name, companyID) - } - if err != nil { - return rc, fmt.Errorf("get category by name: %w", err) - } - return rc, nil -} - func GetPeriodByID(db *sql.DB, periodID int) (model.Period, error) { var p model.Period var start, end string diff --git a/internal/handlers/revenue.go b/internal/handlers/revenue.go index 6e62adb..7eb590f 100644 --- a/internal/handlers/revenue.go +++ b/internal/handlers/revenue.go @@ -1,10 +1,12 @@ package handlers import ( + "Portifolio/internal/database" "Portifolio/internal/model" "Portifolio/internal/service" "database/sql" "encoding/json" + "fmt" "net/http" _ "github.com/mattn/go-sqlite3" @@ -69,3 +71,25 @@ func GetRevenueReportHandler(db *sql.DB) http.HandlerFunc { } } */ + +func GetCompanyRevenueCategories(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var input struct { + CompanyID int `json:"company_id"` + } + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + + catlist, err := database.GetCategoriesByCompanyID(db, input.CompanyID) + if err != nil { + http.Error(w, fmt.Sprintf("Could not find categories by that id:%s", err), http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string][]string{"list": catlist}) + } +} diff --git a/main.go b/main.go index 85efcbd..730bb72 100644 --- a/main.go +++ b/main.go @@ -41,14 +41,16 @@ func main() { http.HandleFunc("GET /positions/list", handlers.GetTradeListHandler(db)) // Company - http.HandleFunc("POST /add/company", handlers.AddCompanyHandler(db)) - http.HandleFunc("GET /companies", handlers.GetCompaniesHandler(db)) + http.HandleFunc("POST /company/add", handlers.AddCompanyHandler(db)) + http.HandleFunc("GET /company/list", handlers.GetCompaniesHandler(db)) + http.HandleFunc("GET /company/revenue/categories", handlers.GetCompanyRevenueCategories(db)) // Currency http.HandleFunc("GET /currencies", handlers.GetCurrenciesHandler(db)) // Revenue http.HandleFunc("POST /add/revenue/entry", handlers.AddRevenueEntryHandler(db)) + //http.HandleFunc("GET /revenue/report", handlers.GetRevenueReportHandler(db)) fmt.Println("Server running on :8080")