diff --git a/Portifolio b/Portifolio index d5239e6..c4e0c7d 100755 Binary files a/Portifolio and b/Portifolio differ diff --git a/app.db b/app.db index a0c9090..a8a8782 100644 Binary files a/app.db and b/app.db differ diff --git a/internal/database/currency.go b/internal/database/currency.go index e5d9fe0..5b2515e 100644 --- a/internal/database/currency.go +++ b/internal/database/currency.go @@ -19,3 +19,15 @@ func GetCurrencyByID(db *sql.DB, ID int) (model.Currency, error) { } return c, nil } + +func GetCurrencyByCode(db *sql.DB, Code string) (model.Currency, error) { + var c model.Currency + err := db.QueryRow( + `SELECT id, code, name, FROM currencies WHERE code = ?`, + Code, + ).Scan(&c.ID, &c.Code, &c.Name) + if err == sql.ErrNoRows { + return c, fmt.Errorf("company %d not found", Code) + } + return c, nil +} diff --git a/internal/database/main.go b/internal/database/main.go index 36603c1..42e71b8 100644 --- a/internal/database/main.go +++ b/internal/database/main.go @@ -23,17 +23,17 @@ func InitDB(db *sql.DB) { shares INTEGER NOT NULL, product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)), type INTEGER NOT NULL CHECK(type IN (0, 1)), - price REAL NOT NULL + price REAL NOT NULL, traded_at DATETIME NOT NULL ); CREATE TABLE IF NOT EXISTS position ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - company_id INTEGER NOT NULL, + id INTEGER PRIMARY KEY AUTOINCREMENT, + company_id INTEGER NOT NULL, currency_id INTEGER NOT NULL, - shares INTEGER NOT NULL, - weight REAL NOT NULL, - CostBases REAL NOT NULL, + shares INTEGER NOT NULL, + weight REAL NOT NULL, + cost_basis REAL NOT NULL ); CREATE TABLE IF NOT EXISTS companies ( diff --git a/internal/database/portfolio.go b/internal/database/portfolio.go index 7483c85..174f0ce 100644 --- a/internal/database/portfolio.go +++ b/internal/database/portfolio.go @@ -3,24 +3,50 @@ package database import ( "Portifolio/internal/model" "database/sql" + "fmt" _ "github.com/mattn/go-sqlite3" ) func GetTrades(db *sql.DB) ([]model.Trade, error) { - rows, err := db.Query("SELECT id, company_id, currency_id, shares, product, type, price, traded_at FROM trades") + rows, err := db.Query("SELECT company_id, currency_id, shares, product, type, price, traded_at FROM trades") if err != nil { return nil, err } defer rows.Close() var trades []model.Trade + for rows.Next() { + var TickerInt int + var CurrencyInt int + var TypeInt int + var t model.Trade - err := rows.Scan(&t.Ticker, &t.Currency, &t.Shares, &t.Product, &t.Type, &t.Price, &t.Date) + err := rows.Scan(&TickerInt, &CurrencyInt, &t.Shares, &t.Product, &TypeInt, &t.Price, &t.Date) if err != nil { return nil, err } + + company, err := GetCompanyByID(db, TickerInt) + if err != nil { + return nil, err + } + currency, err := GetCurrencyByID(db, CurrencyInt) + if err != nil { + return nil, err + } + t.Currency = currency + t.Ticker = *company + switch TypeInt { + case 0: + t.Type = model.TradeType(false) + case 1: + t.Type = model.TradeType(true) + default: + return nil, fmt.Errorf("failed to convert given Type int to bool of trade type.") + } + trades = append(trades, t) } if err = rows.Err(); err != nil { diff --git a/internal/handlers/portfolio.go b/internal/handlers/portfolio.go index 590aac5..58b54bb 100644 --- a/internal/handlers/portfolio.go +++ b/internal/handlers/portfolio.go @@ -5,6 +5,7 @@ import ( "Portifolio/internal/model" "database/sql" "encoding/json" + "fmt" "net/http" _ "github.com/mattn/go-sqlite3" @@ -20,10 +21,35 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc { err := req.Validate() if err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) + http.Error(w, fmt.Sprintf("failed to validate trade: %s", err), http.StatusInternalServerError) return } + // check if currency is in the db. + currency, err := database.GetCurrencyByCode(db, req.Currency) + if err != nil { + http.Error(w, fmt.Sprintf("failed to find currency: %s", err), http.StatusInternalServerError) + return + } + + // check if company is in the db. + company, err := database.GetCompanyByID(db, req.TickerId) + if err != nil { + http.Error(w, fmt.Sprintf("failed to find currency: %s", err), http.StatusInternalServerError) + return + } + + trade := model.Trade{ + Ticker: *company, + Shares: req.Shares, + Product: model.TradeProduct(req.Product), + Type: model.TradeType(req.Type), + Price: req.Price, + Currency: currency, + Date: req.Date, + } + + database.InsertTrade(db, trade) } } @@ -31,13 +57,13 @@ func GetTradeListHandler(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { tradeList, err := database.GetTrades(db) if err != nil { - http.Error(w, "failed to fetch trades", http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("failed to fetch trades:", err), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(tradeList); err != nil { - http.Error(w, "failed to encode trades", http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("failed to encode trades: %s", err), http.StatusInternalServerError) return } } diff --git a/internal/model/portifolio.go b/internal/model/portifolio.go index 598c075..d807613 100644 --- a/internal/model/portifolio.go +++ b/internal/model/portifolio.go @@ -41,13 +41,13 @@ type Trade struct { } type AddTradeRequest struct { - TickerId int - Shares int - Product int - Type bool - Price float64 - Currency string - Date time.Time + TickerId int `json:"ticker_id"` + Shares int `json:"shares"` + Product int `json:"product"` + Type bool `json:"type"` + Price float64 `json:"price"` + Currency string `json:"currency"` + Date time.Time `json:"date"` } func (r *AddTradeRequest) Validate() error {