diff --git a/Portifolio b/Portifolio index 24d44bb..38855cc 100755 Binary files a/Portifolio and b/Portifolio differ diff --git a/app.db b/app.db index cd1eb13..bc7426d 100644 Binary files a/app.db and b/app.db differ diff --git a/internal/database/portfolio.go b/internal/database/portfolio.go index b0aa07a..31ddfe3 100644 --- a/internal/database/portfolio.go +++ b/internal/database/portfolio.go @@ -43,9 +43,9 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) { } func GetPositions(db *sql.DB) ([]model.Position, error) { - rows, err := db.Query("SELECT company_id, symbol, shares, weight, CostBasis, currency_id, currency_code from position") + rows, err := db.Query("SELECT company_id, symbol, shares, weight, cost_basis, currency_id, currency_code from position") if err != nil { - return nil, err + return []model.Position{}, err } defer rows.Close() @@ -55,13 +55,13 @@ func GetPositions(db *sql.DB) ([]model.Position, error) { err := rows.Scan(&t.CompanyID, &t.Symbol, &t.Shares, &t.Weight, &t.CostBasis, &t.CurrencyID, &t.CurrencyCode) if err != nil { - return nil, err + return []model.Position{}, err } positions = append(positions, t) } if err = rows.Err(); err != nil { - return nil, err + return []model.Position{}, err } return positions, nil diff --git a/internal/handlers/portfolio.go b/internal/handlers/portfolio.go index 6b90c60..1dfa9aa 100644 --- a/internal/handlers/portfolio.go +++ b/internal/handlers/portfolio.go @@ -81,15 +81,15 @@ func GetTradeListHandler(db *sql.DB) http.HandlerFunc { func GetPositionListHandler(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - tradeList, err := database.GetPositions(db) + posList, err := database.GetPositions(db) if err != nil { - http.Error(w, "failed to fetch trades", http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("failed to fetch postiton: %s", err), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(tradeList); err != nil { - http.Error(w, "failed to encode trades", http.StatusInternalServerError) + if err := json.NewEncoder(w).Encode(map[string]any{"List": posList}); err != nil { + http.Error(w, fmt.Sprintf("failed to encode positions: %s", err), http.StatusInternalServerError) return } } diff --git a/internal/model/portifolio.go b/internal/model/portifolio.go index 2a54bf3..78be8a5 100644 --- a/internal/model/portifolio.go +++ b/internal/model/portifolio.go @@ -54,7 +54,7 @@ type Trade struct { func (r *AddTradeRequest) Validate() error { if r.Symbol == "" { - return errors.New("empty SYmbol string") + return errors.New("empty Symbol string") } if r.Shares <= 0 { return errors.New("shares must be a positive integer") diff --git a/internal/service/portfolio.go b/internal/service/portfolio.go index 845286a..572abfd 100644 --- a/internal/service/portfolio.go +++ b/internal/service/portfolio.go @@ -16,7 +16,7 @@ func UpdatePositionByTradeList(db *sql.DB) error { fmt.Printf("Failed to get the trades from db: %s", err) } - var TradeSum map[string]model.Position + TradeSum := make(map[string]model.Position) for _, trade := range trades { if trade.Type == model.Buy { diff --git a/main.go b/main.go index a76d0fb..cc1844f 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,17 @@ import ( var db *sql.DB +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("Incoming request: %s %s from %s", r.Method, r.URL.Path, r.Header.Get("Origin")) + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) +} + func main() { var err error db, err = sql.Open("sqlite3", "./app.db?_foreign_keys=on") @@ -32,31 +43,33 @@ func main() { database.InitDB(db) fmt.Println("Connected to SQLite database") - http.HandleFunc("/health", handlers.HealthHandler(db)) + mux := http.NewServeMux() + + mux.HandleFunc("/health", handlers.HealthHandler(db)) //Trades - http.HandleFunc("POST /trade/add", handlers.AddTradeHandler(db)) - http.HandleFunc("GET /trade/list", handlers.GetTradeListHandler(db)) - http.HandleFunc("GET /positions/list", handlers.GetTradeListHandler(db)) + mux.HandleFunc("POST /trade/add", handlers.AddTradeHandler(db)) + mux.HandleFunc("GET /trade/list", handlers.GetTradeListHandler(db)) + mux.HandleFunc("GET /positions/list", handlers.GetPositionListHandler(db)) // Company - http.HandleFunc("POST /company/add", handlers.AddCompanyHandler(db)) - http.HandleFunc("GET /company/list", handlers.GetCompaniesHandler(db)) - http.HandleFunc("GET /company/revenue/categories", handlers.GetCompanyRevenueCategories(db)) + mux.HandleFunc("POST /company/add", handlers.AddCompanyHandler(db)) + mux.HandleFunc("GET /company/list", handlers.GetCompaniesHandler(db)) + mux.HandleFunc("GET /company/revenue/categories", handlers.GetCompanyRevenueCategories(db)) // Currency - http.HandleFunc("GET /currency/list", handlers.GetCurrenciesHandler(db)) - http.HandleFunc("POST /currency/add", handlers.AddCurrencyHandler(db)) + mux.HandleFunc("GET /currency/list", handlers.GetCurrenciesHandler(db)) + mux.HandleFunc("POST /currency/add", handlers.AddCurrencyHandler(db)) // Revenue - http.HandleFunc("POST /add/revenue/entry", handlers.AddRevenueEntryHandler(db)) - http.HandleFunc("POST /api/v1/revenue/add", handlers.AddRevenueEntryHandler(db)) + mux.HandleFunc("POST /add/revenue/entry", handlers.AddRevenueEntryHandler(db)) + mux.HandleFunc("POST /api/v1/revenue/add", handlers.AddRevenueEntryHandler(db)) //http.HandleFunc("GET /revenue/report", handlers.GetRevenueReportHandler(db)) fmt.Println("Server running on :8080") go func() { - log.Fatal(http.ListenAndServe(":8080", nil)) + log.Fatal(http.ListenAndServe(":8080", corsMiddleware(mux))) }() runShell(db)