diff --git a/Portifolio b/Portifolio index 641f399..24d44bb 100755 Binary files a/Portifolio and b/Portifolio differ diff --git a/app.db b/app.db index 7f9fbe3..cd1eb13 100644 Binary files a/app.db and b/app.db differ diff --git a/internal/database/company.go b/internal/database/company.go index c5bd023..3c3ac60 100644 --- a/internal/database/company.go +++ b/internal/database/company.go @@ -8,6 +8,22 @@ import ( _ "github.com/mattn/go-sqlite3" ) +func GetCompanyBySymbol(db *sql.DB, symbol string) (*model.Company, error) { + var c model.Company + err := db.QueryRow( + `SELECT id, symbol, shares_outstanding, price, currency_id FROM companies WHERE symbol = ?`, + symbol, + ).Scan(&c.ID, &c.Symbol, &c.SharesOutstanding, &c.Price, &c.CurrencyID) + + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("query company: %w", err) + } + return &c, nil +} + func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) { var c model.Company err := db.QueryRow( diff --git a/internal/database/portfolio.go b/internal/database/portfolio.go index af9c5e5..b0aa07a 100644 --- a/internal/database/portfolio.go +++ b/internal/database/portfolio.go @@ -9,7 +9,7 @@ import ( ) func GetTrades(db *sql.DB) ([]model.Trade, error) { - rows, err := db.Query("SELECT company_id, symbol, currency_id, currency_code, shares, product, type, price, traded_at FROM trades") + rows, err := db.Query("SELECT symbol, currency_code, shares, product, type, price, traded_at FROM trades") if err != nil { return nil, err } @@ -20,7 +20,7 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) { var typeInt int var t model.Trade - err := rows.Scan(&t.CompanyID, &t.Symbol, &t.CompanyID, &t.CurrencyCode, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date) + err := rows.Scan(&t.Symbol, &t.CurrencyCode, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date) if err != nil { return nil, err } @@ -69,10 +69,8 @@ func GetPositions(db *sql.DB) ([]model.Position, error) { func InsertTrade(db *sql.DB, trade model.Trade) error { _, err := db.Exec( - "INSERT INTO trades (company_id, symbol, currency_id, currency_code, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - trade.CompanyID, + "INSERT INTO trades (symbol, currency_code, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?)", trade.Symbol, - trade.CurrencyID, trade.CurrencyCode, trade.Shares, trade.Product, @@ -82,3 +80,56 @@ func InsertTrade(db *sql.DB, trade model.Trade) error { ) return err } + +func UpdatePositions(db *sql.DB, positions []model.Position) error { + // Complete overwrite of the db positions + _, err := db.Exec("DELETE FROM position") + if err != nil { + return fmt.Errorf("failed to clear positions: %s", err) + } + + for _, p := range positions { + // Resolve company_id if missing + if p.CompanyID == 0 { + company, err := GetCompanyBySymbol(db, p.Symbol) + if err != nil { + return fmt.Errorf("could not find company %s: %s", p.Symbol, err) + } + p.CompanyID = company.ID + } + + // Resolve currency_id if missing + if p.CurrencyID == 0 { + currency, err := GetCurrencyByCode(db, p.CurrencyCode) + if err != nil { + return fmt.Errorf("could not find currency %s: %s", p.CurrencyCode, err) + } + p.CurrencyID = currency.ID + } + + _, err := db.Exec(` + INSERT INTO position (company_id, symbol, currency_id, currency_code, shares, weight, cost_basis) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(company_id) DO UPDATE SET + symbol = excluded.symbol, + currency_id = excluded.currency_id, + currency_code = excluded.currency_code, + shares = excluded.shares, + weight = excluded.weight, + cost_basis = excluded.cost_basis + `, + p.CompanyID, + p.Symbol, + p.CurrencyID, + p.CurrencyCode, + p.Shares, + p.Weight, + p.CostBasis, + ) + if err != nil { + return fmt.Errorf("failed to upsert position %s: %s", p.Symbol, err) + } + } + + return nil +} diff --git a/internal/handlers/portfolio.go b/internal/handlers/portfolio.go index 4279369..6b90c60 100644 --- a/internal/handlers/portfolio.go +++ b/internal/handlers/portfolio.go @@ -3,6 +3,7 @@ package handlers import ( "Portifolio/internal/database" "Portifolio/internal/model" + "Portifolio/internal/service" "database/sql" "encoding/json" "fmt" @@ -32,15 +33,8 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc { return } - // check if company is in the db. - company, err := database.GetCompanyByID(db, req.TickerId) - if err != nil { - http.Error(w, fmt.Sprintf("failed to find currency: %s", err), http.StatusInternalServerError) - return - } - trade := model.Trade{ - Symbol: company.Symbol, + Symbol: req.Symbol, Shares: req.Shares, Product: model.TradeProduct(req.Product), Type: model.TradeType(req.Type), @@ -49,7 +43,23 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc { Date: req.Date, } - database.InsertTrade(db, trade) + err = database.InsertTrade(db, trade) + if err != nil { + http.Error(w, fmt.Sprintf("failed to insert trade into db: %s", err), http.StatusInternalServerError) + return + } + + err = service.UpdatePositionByTradeList(db) + update := true + if err != nil { + update = false + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{"success": true, "position update": update}); err != nil { + http.Error(w, fmt.Sprintf("failed to encode trades: %s", err), http.StatusInternalServerError) + return + } } } diff --git a/internal/model/portifolio.go b/internal/model/portifolio.go index 371586f..2a54bf3 100644 --- a/internal/model/portifolio.go +++ b/internal/model/portifolio.go @@ -32,20 +32,8 @@ const ( Sell TradeType = false ) -type Trade struct { - CompanyID int - Symbol string - CurrencyID int - CurrencyCode string - Shares int - Product TradeProduct - Type TradeType - Price float64 - Date time.Time -} - type AddTradeRequest struct { - TickerId int `json:"ticker_id"` + Symbol string `json:"symbol"` Shares int `json:"shares"` Product int `json:"product"` Type bool `json:"type"` @@ -54,9 +42,19 @@ type AddTradeRequest struct { Date time.Time `json:"date"` } +type Trade struct { + Symbol string + CurrencyCode string + Shares int + Product TradeProduct + Type TradeType + Price float64 + Date time.Time +} + func (r *AddTradeRequest) Validate() error { - if r.TickerId <= 0 { - return errors.New("ticker id must be a positive integer") + if r.Symbol == "" { + return errors.New("empty SYmbol string") } if r.Shares <= 0 { return errors.New("shares must be a positive integer") @@ -67,7 +65,7 @@ func (r *AddTradeRequest) Validate() error { if r.Price <= 0 { return errors.New("price must be a positive number") } - if r.CurrencyCode != "" { + if r.CurrencyCode == "" { return errors.New("currency is required") } if r.Date.IsZero() { diff --git a/internal/service/portfolio.go b/internal/service/portfolio.go new file mode 100644 index 0000000..845286a --- /dev/null +++ b/internal/service/portfolio.go @@ -0,0 +1,51 @@ +package service + +import ( + "Portifolio/internal/database" + "Portifolio/internal/model" + "database/sql" + "fmt" + + _ "github.com/mattn/go-sqlite3" +) + +func UpdatePositionByTradeList(db *sql.DB) error { + + trades, err := database.GetTrades(db) + if err != nil { + fmt.Printf("Failed to get the trades from db: %s", err) + } + + var TradeSum map[string]model.Position + + for _, trade := range trades { + if trade.Type == model.Buy { + TradeSum[trade.Symbol] = model.Position{ + Symbol: trade.Symbol, + CurrencyCode: trade.CurrencyCode, + CostBasis: TradeSum[trade.Symbol].CostBasis + trade.Price, + Shares: TradeSum[trade.Symbol].Shares + trade.Shares, + } + } else { + TradeSum[trade.Symbol] = model.Position{ + Symbol: trade.Symbol, + CurrencyCode: trade.CurrencyCode, + CostBasis: TradeSum[trade.Symbol].CostBasis - trade.Price, + Shares: TradeSum[trade.Symbol].Shares - trade.Shares, + } + } + + } + + var NewPositinos []model.Position + for _, pos := range TradeSum { + NewPositinos = append(NewPositinos, pos) + } + + err = database.UpdatePositions(db, NewPositinos) + if err != nil { + return fmt.Errorf("Failed to insert the new postions number into db: %s", err) + } + + return nil +}