package database import ( "Portifolio/internal/model" "database/sql" "fmt" _ "github.com/mattn/go-sqlite3" ) func GetTrades(db *sql.DB) ([]model.Trade, error) { rows, err := db.Query("SELECT symbol, currency_code, shares, product, type, price, traded_at FROM trades") if err != nil { return nil, err } defer rows.Close() var trades []model.Trade for rows.Next() { var typeInt int var t model.Trade err := rows.Scan(&t.Symbol, &t.CurrencyCode, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date) if err != nil { return nil, err } switch typeInt { case 0: t.Type = model.TradeType(false) case 1: t.Type = model.TradeType(true) default: return nil, fmt.Errorf("failed to convert given Type int to bool of trade type.") } trades = append(trades, t) } if err = rows.Err(); err != nil { return nil, err } return trades, nil } func GetPositions(db *sql.DB) ([]model.Position, error) { rows, err := db.Query("SELECT company_id, symbol, shares, weight, cost_basis, currency_id, currency_code from position") if err != nil { return []model.Position{}, err } defer rows.Close() var positions []model.Position for rows.Next() { var t model.Position err := rows.Scan(&t.CompanyID, &t.Symbol, &t.Shares, &t.Weight, &t.CostBasis, &t.CurrencyID, &t.CurrencyCode) if err != nil { return []model.Position{}, err } positions = append(positions, t) } if err = rows.Err(); err != nil { return []model.Position{}, err } return positions, nil } func InsertTrade(db *sql.DB, trade model.Trade) error { _, err := db.Exec( "INSERT INTO trades (symbol, currency_code, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?)", trade.Symbol, trade.CurrencyCode, trade.Shares, trade.Product, trade.Type, trade.Price, trade.Date, ) return err } func UpdatePositions(db *sql.DB, positions []model.Position) error { // Complete overwrite of the db positions _, err := db.Exec("DELETE FROM position") if err != nil { return fmt.Errorf("failed to clear positions: %s", err) } for _, p := range positions { // Resolve company_id if missing if p.CompanyID == 0 { company, err := GetCompanyBySymbol(db, p.Symbol) if err != nil { return fmt.Errorf("could not find company %s: %s", p.Symbol, err) } p.CompanyID = company.ID } // Resolve currency_id if missing if p.CurrencyID == 0 { currency, err := GetCurrencyByCode(db, p.CurrencyCode) if err != nil { return fmt.Errorf("could not find currency %s: %s", p.CurrencyCode, err) } p.CurrencyID = currency.ID } _, err := db.Exec(` INSERT INTO position (company_id, symbol, currency_id, currency_code, shares, weight, cost_basis) VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT(company_id) DO UPDATE SET symbol = excluded.symbol, currency_id = excluded.currency_id, currency_code = excluded.currency_code, shares = excluded.shares, weight = excluded.weight, cost_basis = excluded.cost_basis `, p.CompanyID, p.Symbol, p.CurrencyID, p.CurrencyCode, p.Shares, p.Weight, p.CostBasis, ) if err != nil { return fmt.Errorf("failed to upsert position %s: %s", p.Symbol, err) } } return nil }