diff --git a/Portifolio b/Portifolio index 61e019a..8251ede 100755 Binary files a/Portifolio and b/Portifolio differ diff --git a/app.db b/app.db index f8ea132..e965b08 100644 Binary files a/app.db and b/app.db differ diff --git a/internal/database/main.go b/internal/database/main.go index 42e71b8..fe46214 100644 --- a/internal/database/main.go +++ b/internal/database/main.go @@ -19,7 +19,7 @@ func InitDB(db *sql.DB) { CREATE TABLE IF NOT EXISTS trades ( id INTEGER PRIMARY KEY AUTOINCREMENT, company_id INTEGER NOT NULL, - currency_id INTEGER NOT NULL, + currency_code TEXT NOT NULL, shares INTEGER NOT NULL, product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)), type INTEGER NOT NULL CHECK(type IN (0, 1)), @@ -118,3 +118,34 @@ func MigrateAddUniqueToRevenueEntries(db *sql.DB) error { } return nil } + +func MigrateTradeCode(db *sql.DB) error { + step := + // 1. copy existing data into a temp table with the new constraint + ` + ALTER TABLE trades RENAME TO trades_old; + +CREATE TABLE IF NOT EXISTS trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + company_id INTEGER NOT NULL, + currency_code TEXT NOT NULL, + shares INTEGER NOT NULL, + product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)), + type INTEGER NOT NULL CHECK(type IN (0, 1)), + price REAL NOT NULL, + traded_at DATETIME NOT NULL +); + +INSERT INTO trades (id, company_id, currency_code, shares, product, type, price, traded_at) +SELECT id, company_id, '', shares, product, type, price, traded_at +FROM trades_old; + +DROP TABLE trades_old; + ` + + _, err := db.Exec(step) + if err != nil { + return fmt.Errorf("migration failed: %w", err) + } + return nil +} diff --git a/internal/database/portfolio.go b/internal/database/portfolio.go index d548e8f..60d122d 100644 --- a/internal/database/portfolio.go +++ b/internal/database/portfolio.go @@ -9,36 +9,30 @@ import ( ) func GetTrades(db *sql.DB) ([]model.Trade, error) { - rows, err := db.Query("SELECT company_id, currency_id, shares, product, type, price, traded_at FROM trades") + rows, err := db.Query("SELECT company_id, currency_code, shares, product, type, price, traded_at FROM trades") if err != nil { return nil, err } defer rows.Close() var trades []model.Trade - for rows.Next() { - var TickerInt int - var CurrencyInt int - var TypeInt int - + var tickerInt int + var typeInt int var t model.Trade - err := rows.Scan(&TickerInt, &CurrencyInt, &t.Shares, &t.Product, &TypeInt, &t.Price, &t.Date) + + err := rows.Scan(&tickerInt, &t.Currency, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date) if err != nil { return nil, err } - company, err := GetCompanyByID(db, TickerInt) + company, err := GetCompanyByID(db, tickerInt) if err != nil { return nil, err } - currency, err := GetCurrencyByID(db, CurrencyInt) - if err != nil { - return nil, err - } - t.Currency = currency.Name t.Ticker = *company - switch TypeInt { + + switch typeInt { case 0: t.Type = model.TradeType(false) case 1: @@ -52,7 +46,6 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) { if err = rows.Err(); err != nil { return nil, err } - return trades, nil } diff --git a/internal/handlers/portfolio.go b/internal/handlers/portfolio.go index 5673929..72641e8 100644 --- a/internal/handlers/portfolio.go +++ b/internal/handlers/portfolio.go @@ -26,7 +26,7 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc { } // check if currency is in the db. - currency, err := database.GetCurrencyByCode(db, req.Currency) + currency, err := database.GetCurrencyByCode(db, req.CurrencyCode) if err != nil { http.Error(w, fmt.Sprintf("failed to find currency: %s", err), http.StatusInternalServerError) return @@ -57,7 +57,7 @@ func GetTradeListHandler(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { tradeList, err := database.GetTrades(db) if err != nil { - http.Error(w, fmt.Sprintf("failed to fetch trades:", err), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("failed to fetch trades: %s", err), http.StatusInternalServerError) return } diff --git a/internal/model/portifolio.go b/internal/model/portifolio.go index 553c084..7ab6172 100644 --- a/internal/model/portifolio.go +++ b/internal/model/portifolio.go @@ -41,13 +41,13 @@ type Trade struct { } type AddTradeRequest struct { - TickerId int `json:"ticker_id"` - Shares int `json:"shares"` - Product int `json:"product"` - Type bool `json:"type"` - Price float64 `json:"price"` - Currency string `json:"currency"` - Date time.Time `json:"date"` + TickerId int `json:"ticker_id"` + Shares int `json:"shares"` + Product int `json:"product"` + Type bool `json:"type"` + Price float64 `json:"price"` + CurrencyCode string `json:"currency_code"` + Date time.Time `json:"date"` } func (r *AddTradeRequest) Validate() error { @@ -63,7 +63,7 @@ func (r *AddTradeRequest) Validate() error { if r.Price <= 0 { return errors.New("price must be a positive number") } - if r.Currency == "" { + if r.CurrencyCode != "" { return errors.New("currency is required") } if r.Date.IsZero() { diff --git a/main.go b/main.go index 744ac0b..f19970f 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,7 @@ func main() { database.InitDB(db) database.MigrateAddUniqueToRevenueEntries(db) + database.MigrateTradeCode(db) fmt.Println("Connected to SQLite database") http.HandleFunc("/health", handlers.HealthHandler(db))