diff --git a/app.db b/app.db index bc7426d..46399b2 100644 Binary files a/app.db and b/app.db differ diff --git a/internal/database/main.go b/internal/database/main.go index 06f51a3..33f6798 100644 --- a/internal/database/main.go +++ b/internal/database/main.go @@ -17,14 +17,26 @@ func InitDB(db *sql.DB) { ); CREATE TABLE IF NOT EXISTS trades ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT NOT NULL, - currency_code TEXT NOT NULL, - shares INTEGER NOT NULL, - product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)), - type INTEGER NOT NULL CHECK(type IN (0, 1)), - price REAL NOT NULL, - traded_at DATETIME NOT NULL + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT NOT NULL, + currency_code TEXT NOT NULL, + shares INTEGER NOT NULL, + product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3, 4)), -- added 4 for BondTrade + type INTEGER NOT NULL CHECK(type IN (0, 1)), -- Buy=0, Sell=1 only; Dividend has its own table + price REAL NOT NULL, + traded_at DATETIME NOT NULL + ); + + CREATE TABLE IF NOT EXISTS dividends ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT NOT NULL, + currency_code TEXT NOT NULL, + product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3, 4)), + value REAL NOT NULL, + tax_amount REAL NOT NULL DEFAULT 0, + tax_rate REAL NOT NULL DEFAULT 0, + net_value REAL NOT NULL, + payment_date DATETIME NOT NULL ); CREATE TABLE IF NOT EXISTS position ( @@ -69,6 +81,29 @@ func InitDB(db *sql.DB) { UNIQUE(company_id, name) ); + -- parent table + CREATE TABLE closed_positions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT NOT NULL, + currency_code TEXT NOT NULL, + product TEXT NOT NULL, + open_time DATETIME NOT NULL, + realized_gain REAL, + tax_amount REAL, + holding_days INTEGER + ); + + -- child table, one row per close lot + CREATE TABLE close_entries ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + closed_position_id INTEGER NOT NULL REFERENCES closed_positions(id), + shares INTEGER NOT NULL, + in_price REAL NOT NULL, + out_price REAL NOT NULL, + gain_price REAL NOT NULL, + close_time DATETIME NOT NULL + ); + CREATE TABLE IF NOT EXISTS revenue_entries ( id INTEGER PRIMARY KEY AUTOINCREMENT, company_id INTEGER NOT NULL, diff --git a/internal/database/portfolio.go b/internal/database/portfolio.go index 31ddfe3..0d6d354 100644 --- a/internal/database/portfolio.go +++ b/internal/database/portfolio.go @@ -25,14 +25,7 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) { return nil, err } - switch typeInt { - case 0: - t.Type = model.TradeType(false) - case 1: - t.Type = model.TradeType(true) - default: - return nil, fmt.Errorf("failed to convert given Type int to bool of trade type.") - } + t.Type = model.TradeType(typeInt) trades = append(trades, t) } @@ -81,6 +74,21 @@ func InsertTrade(db *sql.DB, trade model.Trade) error { return err } +func InsertDividend(db *sql.DB, div model.Dividend) error { + _, err := db.Exec( + "INSERT INTO trades (symbol, currency_code, shares, product, value, tax_amount, tax_rate, net_value, payment_date) VALUES (?, ?, ?, ?, ?, ?, ?)", + div.Symbol, + div.CurrencyCode, + div.Product, + div.Value, + div.TaxAmount, + div.TaxRate, + div.NetValue, + div.PaymentDate, + ) + return err +} + func UpdatePositions(db *sql.DB, positions []model.Position) error { // Complete overwrite of the db positions _, err := db.Exec("DELETE FROM position") diff --git a/internal/handlers/portfolio.go b/internal/handlers/portfolio.go index 9f9594b..8fb1e85 100644 --- a/internal/handlers/portfolio.go +++ b/internal/handlers/portfolio.go @@ -20,45 +20,57 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc { return } - err := req.Validate() - if err != nil { - http.Error(w, fmt.Sprintf("failed to validate trade: %s", err), http.StatusInternalServerError) + if err := req.Validate(); err != nil { + http.Error(w, fmt.Sprintf("failed to validate trade: %s", err), http.StatusBadRequest) return } - // check if currency is in the db. currency, err := database.GetCurrencyByCode(db, req.CurrencyCode) if err != nil { http.Error(w, fmt.Sprintf("failed to find currency: %s", err), http.StatusInternalServerError) return } - trade := model.Trade{ - Symbol: req.Symbol, - Shares: req.Shares, - Product: model.TradeProduct(req.Product), - Type: model.TradeType(req.Type), - Price: req.Price, - CurrencyCode: currency.Code, - Date: req.Date, - } + switch model.TradeType(req.Type) { + case model.DividendType: + dividend, err := req.ToDividend() + if err != nil { + http.Error(w, fmt.Sprintf("failed to build dividend: %s", err), http.StatusBadRequest) + return + } + dividend.CurrencyCode = currency.Code - err = database.InsertTrade(db, trade) - if err != nil { - http.Error(w, fmt.Sprintf("failed to insert trade into db: %s", err), http.StatusInternalServerError) - return - } + if err := database.InsertDividend(db, dividend); err != nil { + http.Error(w, fmt.Sprintf("failed to insert dividend: %s", err), http.StatusInternalServerError) + return + } - err = service.UpdatePositionByTradeList(db) - update := true - if err != nil { - update = false - } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{"success": true}) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{"success": true, "position update": update}); err != nil { - http.Error(w, fmt.Sprintf("failed to encode trades: %s", err), http.StatusInternalServerError) - return + case model.BuyType, model.SellType: + trade, err := req.ToTrade() + if err != nil { + http.Error(w, fmt.Sprintf("failed to build trade: %s", err), http.StatusBadRequest) + return + } + trade.CurrencyCode = currency.Code + + if err := database.InsertTrade(db, trade); err != nil { + http.Error(w, fmt.Sprintf("failed to insert trade: %s", err), http.StatusInternalServerError) + return + } + + update := true + if err := service.UpdatePositionByTradeList(db); err != nil { + update = false + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{"success": true, "position_update": update}) + + default: + http.Error(w, fmt.Sprintf("unknown trade type: %d", req.Type), http.StatusBadRequest) } } } diff --git a/internal/model/portifolio.go b/internal/model/portifolio.go index 78be8a5..fa40695 100644 --- a/internal/model/portifolio.go +++ b/internal/model/portifolio.go @@ -2,9 +2,40 @@ package model import ( "errors" + "fmt" "time" ) +type Dividend struct { + Symbol string + CurrencyCode string + Product TradeProduct + Value float64 + PaymentDate time.Time + TaxAmount float64 + TaxRate float64 + NetValue float64 +} + +type CloseEntry struct { + Shares int + InPrice float64 + OutPrice float64 + GainPrice float64 + CloseTime time.Time +} + +type ClosedPosition struct { + Symbol string + CurrencyCode string + Product TradeProduct + Closes []CloseEntry // each close carries its own prices + share count + OpenTime time.Time + RealizedGain float64 + TaxAmount float64 + HoldingDays int +} + type Position struct { CompanyID int Symbol string @@ -25,21 +56,28 @@ const ( BondTrade ) -type TradeType bool +type TradeType int const ( - Buy TradeType = true - Sell TradeType = false + BuyType TradeType = iota // 0 + SellType // 1 + DividendType // 2 ) type AddTradeRequest struct { Symbol string `json:"symbol"` Shares int `json:"shares"` Product int `json:"product"` - Type bool `json:"type"` + Type int `json:"type"` // was bool, now int Price float64 `json:"price"` CurrencyCode string `json:"currency_code"` Date time.Time `json:"date"` + + // Dividend-specific fields (only populated when Type == 2) + TaxAmount float64 `json:"tax_amount,omitempty"` + TaxRate float64 `json:"tax_rate,omitempty"` + NetValue float64 `json:"net_value,omitempty"` + PaymentDate time.Time `json:"payment_date,omitempty"` } type Trade struct { @@ -77,8 +115,42 @@ func (r *AddTradeRequest) Validate() error { return nil } +func (r AddTradeRequest) ToDividend() (Dividend, error) { + if TradeType(r.Type) != DividendType { + return Dividend{}, fmt.Errorf("trade type is not a dividend") + } + return Dividend{ + Symbol: r.Symbol, + CurrencyCode: r.CurrencyCode, + Product: TradeProduct(r.Product), + Value: r.Price, // gross value + PaymentDate: r.PaymentDate, + TaxAmount: r.TaxAmount, + TaxRate: r.TaxRate, + NetValue: r.NetValue, + }, nil +} + +func (r AddTradeRequest) ToTrade() (Trade, error) { + t := TradeType(r.Type) + if t != BuyType && t != SellType { + return Trade{}, fmt.Errorf("trade type is not buy or sell") + } + return Trade{ + Symbol: r.Symbol, + CurrencyCode: r.CurrencyCode, + Shares: r.Shares, + Product: TradeProduct(r.Product), + Type: t, + Price: r.Price, + Date: r.Date, + }, nil +} + // for now trades and none stock position will not be supported. type Portifolio struct { Positions []Position Trades []Trade + closed []ClosedPosition + Dividends []Dividend } diff --git a/internal/service/portfolio.go b/internal/service/portfolio.go index 572abfd..bd55c2f 100644 --- a/internal/service/portfolio.go +++ b/internal/service/portfolio.go @@ -19,7 +19,7 @@ func UpdatePositionByTradeList(db *sql.DB) error { TradeSum := make(map[string]model.Position) for _, trade := range trades { - if trade.Type == model.Buy { + if trade.Type == model.BuyType { TradeSum[trade.Symbol] = model.Position{ Symbol: trade.Symbol, CurrencyCode: trade.CurrencyCode, diff --git a/main.go b/main.go index 69c2277..16be328 100644 --- a/main.go +++ b/main.go @@ -65,11 +65,17 @@ func main() { mux.HandleFunc("POST /trade/add", handlers.AddTradeHandler(db)) mux.HandleFunc("GET /trade/list", handlers.GetTradeListHandler(db)) + mux.HandleFunc("GET /trade/search", handlers.GetTradeListHandler(db)) // new + //Positions mux.HandleFunc("GET /positions/list", handlers.GetPositionListHandler(db)) - + mux.HandleFunc("GET /positions/closed/list", handlers.GetPositionListHandler(db)) // new + mux.HandleFunc("GET /positions/closed/search", handlers.GetTradeListHandler(db)) // new + // Company mux.HandleFunc("POST /company/add", handlers.AddCompanyHandler(db)) mux.HandleFunc("GET /company/list", handlers.GetCompaniesHandler(db)) mux.HandleFunc("GET /company/revenue/categories", handlers.GetCompanyRevenueCategories(db)) + mux.HandleFunc("POST /company/S-O/add", handlers.GetCompaniesHandler(db)) // new + mux.HandleFunc("GET /company/S-O/list", handlers.GetCompaniesHandler(db)) // new mux.HandleFunc("GET /currency/list", handlers.GetCurrenciesHandler(db)) mux.HandleFunc("POST /currency/add", handlers.AddCurrencyHandler(db))