better trade endpoints

This commit is contained in:
samantha42
2026-03-25 22:17:19 +01:00
parent d491b9c14c
commit ff7b41e2a8
7 changed files with 51 additions and 26 deletions

Binary file not shown.

BIN
app.db

Binary file not shown.

View File

@@ -19,7 +19,7 @@ func InitDB(db *sql.DB) {
CREATE TABLE IF NOT EXISTS trades ( CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL, company_id INTEGER NOT NULL,
currency_id INTEGER NOT NULL, currency_code TEXT NOT NULL,
shares INTEGER NOT NULL, shares INTEGER NOT NULL,
product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)), product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)),
type INTEGER NOT NULL CHECK(type IN (0, 1)), type INTEGER NOT NULL CHECK(type IN (0, 1)),
@@ -118,3 +118,34 @@ func MigrateAddUniqueToRevenueEntries(db *sql.DB) error {
} }
return nil return nil
} }
func MigrateTradeCode(db *sql.DB) error {
step :=
// 1. copy existing data into a temp table with the new constraint
`
ALTER TABLE trades RENAME TO trades_old;
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL,
currency_code TEXT NOT NULL,
shares INTEGER NOT NULL,
product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)),
type INTEGER NOT NULL CHECK(type IN (0, 1)),
price REAL NOT NULL,
traded_at DATETIME NOT NULL
);
INSERT INTO trades (id, company_id, currency_code, shares, product, type, price, traded_at)
SELECT id, company_id, '', shares, product, type, price, traded_at
FROM trades_old;
DROP TABLE trades_old;
`
_, err := db.Exec(step)
if err != nil {
return fmt.Errorf("migration failed: %w", err)
}
return nil
}

View File

@@ -9,36 +9,30 @@ import (
) )
func GetTrades(db *sql.DB) ([]model.Trade, error) { func GetTrades(db *sql.DB) ([]model.Trade, error) {
rows, err := db.Query("SELECT company_id, currency_id, shares, product, type, price, traded_at FROM trades") rows, err := db.Query("SELECT company_id, currency_code, shares, product, type, price, traded_at FROM trades")
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var trades []model.Trade var trades []model.Trade
for rows.Next() { for rows.Next() {
var TickerInt int var tickerInt int
var CurrencyInt int var typeInt int
var TypeInt int
var t model.Trade var t model.Trade
err := rows.Scan(&TickerInt, &CurrencyInt, &t.Shares, &t.Product, &TypeInt, &t.Price, &t.Date)
err := rows.Scan(&tickerInt, &t.Currency, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date)
if err != nil { if err != nil {
return nil, err return nil, err
} }
company, err := GetCompanyByID(db, TickerInt) company, err := GetCompanyByID(db, tickerInt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
currency, err := GetCurrencyByID(db, CurrencyInt)
if err != nil {
return nil, err
}
t.Currency = currency.Name
t.Ticker = *company t.Ticker = *company
switch TypeInt {
switch typeInt {
case 0: case 0:
t.Type = model.TradeType(false) t.Type = model.TradeType(false)
case 1: case 1:
@@ -52,7 +46,6 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) {
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
return trades, nil return trades, nil
} }

View File

@@ -26,7 +26,7 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc {
} }
// check if currency is in the db. // check if currency is in the db.
currency, err := database.GetCurrencyByCode(db, req.Currency) currency, err := database.GetCurrencyByCode(db, req.CurrencyCode)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("failed to find currency: %s", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to find currency: %s", err), http.StatusInternalServerError)
return return
@@ -57,7 +57,7 @@ func GetTradeListHandler(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tradeList, err := database.GetTrades(db) tradeList, err := database.GetTrades(db)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("failed to fetch trades:", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to fetch trades: %s", err), http.StatusInternalServerError)
return return
} }

View File

@@ -46,7 +46,7 @@ type AddTradeRequest struct {
Product int `json:"product"` Product int `json:"product"`
Type bool `json:"type"` Type bool `json:"type"`
Price float64 `json:"price"` Price float64 `json:"price"`
Currency string `json:"currency"` CurrencyCode string `json:"currency_code"`
Date time.Time `json:"date"` Date time.Time `json:"date"`
} }
@@ -63,7 +63,7 @@ func (r *AddTradeRequest) Validate() error {
if r.Price <= 0 { if r.Price <= 0 {
return errors.New("price must be a positive number") return errors.New("price must be a positive number")
} }
if r.Currency == "" { if r.CurrencyCode != "" {
return errors.New("currency is required") return errors.New("currency is required")
} }
if r.Date.IsZero() { if r.Date.IsZero() {

View File

@@ -31,6 +31,7 @@ func main() {
database.InitDB(db) database.InitDB(db)
database.MigrateAddUniqueToRevenueEntries(db) database.MigrateAddUniqueToRevenueEntries(db)
database.MigrateTradeCode(db)
fmt.Println("Connected to SQLite database") fmt.Println("Connected to SQLite database")
http.HandleFunc("/health", handlers.HealthHandler(db)) http.HandleFunc("/health", handlers.HealthHandler(db))