diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b786662 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +data/** \ No newline at end of file diff --git a/Portifolio b/Portifolio index 8251ede..641f399 100755 Binary files a/Portifolio and b/Portifolio differ diff --git a/app.db b/app.db index e965b08..7f9fbe3 100644 Binary files a/app.db and b/app.db differ diff --git a/internal/database/company.go b/internal/database/company.go index 6be510c..c5bd023 100644 --- a/internal/database/company.go +++ b/internal/database/company.go @@ -11,9 +11,9 @@ import ( func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) { var c model.Company err := db.QueryRow( - `SELECT id, name, shares_outstanding, price, currency_id FROM companies WHERE id = ?`, + `SELECT id, symbol, shares_outstanding, price, currency_id FROM companies WHERE id = ?`, id, - ).Scan(&c.ID, &c.Name, &c.SharesOutstanding, &c.Price, &c.CurrencyID) + ).Scan(&c.ID, &c.Symbol, &c.SharesOutstanding, &c.Price, &c.CurrencyID) if err == sql.ErrNoRows { return nil, nil @@ -23,3 +23,49 @@ func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) { } return &c, nil } + +func AddCompany(db *sql.DB, input model.CompanyInput) (int, error) { + if input.CurrencyID == 0 { + if input.CurrencyCode != "" { + currency, err := GetCurrencyByCode(db, input.CurrencyCode) + if err != nil { + return 0, fmt.Errorf("could not get currency: %s", err) + } + input.CurrencyID = currency.ID + } else { + return 0, fmt.Errorf("no currency reference") + } + } + + res, err := db.Exec( + `INSERT INTO companies (symbol, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`, + input.Symbol, input.SharesOutstanding, input.Price, input.CurrencyID, + ) + if err != nil { + return 0, fmt.Errorf("failed to insert: %s", err) + } + id, err := res.LastInsertId() + return int(id), err +} + +func GetAllCompanies(db *sql.DB) ([]model.Company, error) { + rows, err := db.Query(` + SELECT id, symbol, shares_outstanding, price, currency_id FROM companies + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var companies []model.Company + for rows.Next() { + var c model.Company + if err := rows.Scan( + &c.ID, &c.Symbol, &c.SharesOutstanding, &c.Price, &c.CurrencyID, + ); err != nil { + return nil, err + } + companies = append(companies, c) + } + return companies, rows.Err() +} diff --git a/internal/database/currency.go b/internal/database/currency.go index 5b2515e..b3969da 100644 --- a/internal/database/currency.go +++ b/internal/database/currency.go @@ -20,14 +20,22 @@ func GetCurrencyByID(db *sql.DB, ID int) (model.Currency, error) { return c, nil } +/* + CREATE TABLE IF NOT EXISTS currencies ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + code TEXT NOT NULL UNIQUE, + name TEXT NOT NULL + ); +*/ + func GetCurrencyByCode(db *sql.DB, Code string) (model.Currency, error) { var c model.Currency err := db.QueryRow( - `SELECT id, code, name, FROM currencies WHERE code = ?`, + `SELECT id, code, name FROM currencies WHERE code = ?`, Code, ).Scan(&c.ID, &c.Code, &c.Name) if err == sql.ErrNoRows { - return c, fmt.Errorf("company %d not found", Code) + return c, fmt.Errorf("company %s not found", Code) } return c, nil } diff --git a/internal/database/main.go b/internal/database/main.go index fe46214..06f51a3 100644 --- a/internal/database/main.go +++ b/internal/database/main.go @@ -18,7 +18,7 @@ func InitDB(db *sql.DB) { CREATE TABLE IF NOT EXISTS trades ( id INTEGER PRIMARY KEY AUTOINCREMENT, - company_id INTEGER NOT NULL, + symbol TEXT NOT NULL, currency_code TEXT NOT NULL, shares INTEGER NOT NULL, product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)), @@ -29,16 +29,20 @@ func InitDB(db *sql.DB) { CREATE TABLE IF NOT EXISTS position ( id INTEGER PRIMARY KEY AUTOINCREMENT, - company_id INTEGER NOT NULL, + company_id INTEGER NOT NULL UNIQUE, + symbol TEXT NOT NULL, currency_id INTEGER NOT NULL, + currency_code TEXT NOT NULL, shares INTEGER NOT NULL, weight REAL NOT NULL, - cost_basis REAL NOT NULL + cost_basis REAL NOT NULL, + FOREIGN KEY (currency_id) REFERENCES currencies(id), + FOREIGN KEY (company_id) REFERENCES companies(id) ); CREATE TABLE IF NOT EXISTS companies ( id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, + symbol TEXT NOT NULL UNIQUE, shares_outstanding INTEGER NOT NULL, price REAL NOT NULL, currency_id INTEGER NOT NULL, @@ -84,68 +88,3 @@ func InitDB(db *sql.DB) { } fmt.Println("Tables ready") } - -func MigrateAddUniqueToRevenueEntries(db *sql.DB) error { - steps := []string{ - // 1. copy existing data into a temp table with the new constraint - `CREATE TABLE IF NOT EXISTS revenue_entries_new ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - company_id INTEGER NOT NULL, - currency_id INTEGER NOT NULL, - category_id INTEGER NOT NULL, - period_id INTEGER NOT NULL, - value REAL NOT NULL, - FOREIGN KEY (company_id) REFERENCES companies(id), - FOREIGN KEY (currency_id) REFERENCES currencies(id), - FOREIGN KEY (category_id) REFERENCES category(id), - FOREIGN KEY (period_id) REFERENCES periods(id), - UNIQUE(company_id, category_id, period_id) - )`, - // 2. copy data over - `INSERT OR IGNORE INTO revenue_entries_new - SELECT id, company_id, currency_id, category_id, period_id, value - FROM revenue_entries`, - // 3. drop old table - `DROP TABLE revenue_entries`, - // 4. rename new table - `ALTER TABLE revenue_entries_new RENAME TO revenue_entries`, - } - - for _, step := range steps { - if _, err := db.Exec(step); err != nil { - return fmt.Errorf("migration failed: %w", err) - } - } - return nil -} - -func MigrateTradeCode(db *sql.DB) error { - step := - // 1. copy existing data into a temp table with the new constraint - ` - ALTER TABLE trades RENAME TO trades_old; - -CREATE TABLE IF NOT EXISTS trades ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - company_id INTEGER NOT NULL, - currency_code TEXT NOT NULL, - shares INTEGER NOT NULL, - product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)), - type INTEGER NOT NULL CHECK(type IN (0, 1)), - price REAL NOT NULL, - traded_at DATETIME NOT NULL -); - -INSERT INTO trades (id, company_id, currency_code, shares, product, type, price, traded_at) -SELECT id, company_id, '', shares, product, type, price, traded_at -FROM trades_old; - -DROP TABLE trades_old; - ` - - _, err := db.Exec(step) - if err != nil { - return fmt.Errorf("migration failed: %w", err) - } - return nil -} diff --git a/internal/database/portfolio.go b/internal/database/portfolio.go index 60d122d..af9c5e5 100644 --- a/internal/database/portfolio.go +++ b/internal/database/portfolio.go @@ -9,7 +9,7 @@ import ( ) func GetTrades(db *sql.DB) ([]model.Trade, error) { - rows, err := db.Query("SELECT company_id, currency_code, shares, product, type, price, traded_at FROM trades") + rows, err := db.Query("SELECT company_id, symbol, currency_id, currency_code, shares, product, type, price, traded_at FROM trades") if err != nil { return nil, err } @@ -17,21 +17,14 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) { var trades []model.Trade for rows.Next() { - var tickerInt int var typeInt int var t model.Trade - err := rows.Scan(&tickerInt, &t.Currency, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date) + err := rows.Scan(&t.CompanyID, &t.Symbol, &t.CompanyID, &t.CurrencyCode, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date) if err != nil { return nil, err } - company, err := GetCompanyByID(db, tickerInt) - if err != nil { - return nil, err - } - t.Ticker = *company - switch typeInt { case 0: t.Type = model.TradeType(false) @@ -50,7 +43,7 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) { } func GetPositions(db *sql.DB) ([]model.Position, error) { - rows, err := db.Query("SELECT company_id, shares, weight, CostBases, currency_id") + rows, err := db.Query("SELECT company_id, symbol, shares, weight, CostBasis, currency_id, currency_code from position") if err != nil { return nil, err } @@ -59,10 +52,12 @@ func GetPositions(db *sql.DB) ([]model.Position, error) { var positions []model.Position for rows.Next() { var t model.Position - err := rows.Scan(&t.Company.ID, &t.Shares, &t.Weight, &t.CostBasis, t.Currency) + + err := rows.Scan(&t.CompanyID, &t.Symbol, &t.Shares, &t.Weight, &t.CostBasis, &t.CurrencyID, &t.CurrencyCode) if err != nil { return nil, err } + positions = append(positions, t) } if err = rows.Err(); err != nil { @@ -74,9 +69,11 @@ func GetPositions(db *sql.DB) ([]model.Position, error) { func InsertTrade(db *sql.DB, trade model.Trade) error { _, err := db.Exec( - "INSERT INTO trades (company_id, currency_id, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - trade.Ticker.ID, - trade.Currency, + "INSERT INTO trades (company_id, symbol, currency_id, currency_code, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + trade.CompanyID, + trade.Symbol, + trade.CurrencyID, + trade.CurrencyCode, trade.Shares, trade.Product, trade.Type, diff --git a/internal/handlers/main.go b/internal/handlers/main.go index 4162c3e..8deaae8 100644 --- a/internal/handlers/main.go +++ b/internal/handlers/main.go @@ -1,8 +1,8 @@ package handlers import ( + "Portifolio/internal/database" "Portifolio/internal/model" - "Portifolio/internal/service" "database/sql" "encoding/json" "net/http" @@ -79,11 +79,11 @@ func AddCompanyHandler(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var input model.CompanyInput if err := json.NewDecoder(r.Body).Decode(&input); err != nil { - http.Error(w, "invalid json", http.StatusBadRequest) + http.Error(w, "invalid request body", http.StatusBadRequest) return } - id, err := service.InsertCompany(db, input) + id, err := database.AddCompany(db, input) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -91,13 +91,13 @@ func AddCompanyHandler(db *sql.DB) http.HandlerFunc { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) - json.NewEncoder(w).Encode(map[string]any{"status": "created", "id": id}) + json.NewEncoder(w).Encode(map[string]int{"id": id}) } } func GetCompaniesHandler(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - companies, err := service.GetAllCompanies(db) + companies, err := database.GetAllCompanies(db) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/internal/handlers/portfolio.go b/internal/handlers/portfolio.go index 72641e8..4279369 100644 --- a/internal/handlers/portfolio.go +++ b/internal/handlers/portfolio.go @@ -40,13 +40,13 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc { } trade := model.Trade{ - Ticker: *company, - Shares: req.Shares, - Product: model.TradeProduct(req.Product), - Type: model.TradeType(req.Type), - Price: req.Price, - Currency: currency.Name, - Date: req.Date, + Symbol: company.Symbol, + Shares: req.Shares, + Product: model.TradeProduct(req.Product), + Type: model.TradeType(req.Type), + Price: req.Price, + CurrencyCode: currency.Code, + Date: req.Date, } database.InsertTrade(db, trade) diff --git a/internal/model/company.go b/internal/model/company.go index 43d6bb4..cf8515e 100644 --- a/internal/model/company.go +++ b/internal/model/company.go @@ -2,15 +2,16 @@ package model type Company struct { ID int - Name string + Symbol string SharesOutstanding int Price float64 CurrencyID int } type CompanyInput struct { - Name string `json:"name"` + Symbol string `json:"symbol"` SharesOutstanding int `json:"shares_outstanding"` Price float64 `json:"price"` CurrencyID int `json:"currency_id"` + CurrencyCode string `json:"currency_code"` } diff --git a/internal/model/portifolio.go b/internal/model/portifolio.go index 7ab6172..371586f 100644 --- a/internal/model/portifolio.go +++ b/internal/model/portifolio.go @@ -6,11 +6,13 @@ import ( ) type Position struct { - Company Company - Currency Currency - Weight float64 - CostBasis float64 - Shares int + CompanyID int + Symbol string + CurrencyCode string + CurrencyID int + Weight float64 + CostBasis float64 + Shares int } type TradeProduct int @@ -31,13 +33,15 @@ const ( ) type Trade struct { - Ticker Company - Shares int - Product TradeProduct - Type TradeType - Price float64 - Currency string - Date time.Time + CompanyID int + Symbol string + CurrencyID int + CurrencyCode string + Shares int + Product TradeProduct + Type TradeType + Price float64 + Date time.Time } type AddTradeRequest struct { diff --git a/internal/service/company.go b/internal/service/company.go deleted file mode 100644 index ed3ef16..0000000 --- a/internal/service/company.go +++ /dev/null @@ -1,49 +0,0 @@ -package service - -import ( - "Portifolio/internal/model" - "database/sql" - - _ "github.com/mattn/go-sqlite3" -) - -func InsertCompany(db *sql.DB, input model.CompanyInput) (int, error) { - res, err := db.Exec( - `INSERT INTO companies (name, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`, - input.Name, input.SharesOutstanding, input.Price, input.CurrencyID, - ) - if err != nil { - return 0, err - } - id, err := res.LastInsertId() - return int(id), err -} - -func GetAllCompanies(db *sql.DB) ([]model.Company, error) { - rows, err := db.Query(` - SELECT c.id, c.name, c.shares_outstanding, c.price, - cu.id, cu.code, cu.name - FROM companies c - JOIN currencies cu ON c.currency_id = cu.id - ORDER BY c.name - `) - if err != nil { - return nil, err - } - defer rows.Close() - - var companies []model.Company - for rows.Next() { - var c model.Company - var cu model.Currency - if err := rows.Scan( - &c.ID, &c.Name, &c.SharesOutstanding, &c.Price, - &cu.ID, &cu.Code, &cu.Name, - ); err != nil { - return nil, err - } - c.CurrencyID = cu.ID - companies = append(companies, c) - } - return companies, rows.Err() -} diff --git a/internal/service/main.go b/internal/service/main.go deleted file mode 100644 index 5699681..0000000 --- a/internal/service/main.go +++ /dev/null @@ -1,16 +0,0 @@ -package service - -import ( - "Portifolio/internal/model" - "database/sql" - - _ "github.com/mattn/go-sqlite3" -) - -func AddCompany(input model.CompanyInput, db *sql.DB) error { - _, err := db.Exec( - `INSERT INTO companies (name, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`, - input.Name, input.SharesOutstanding, input.Price, input.CurrencyID, - ) - return err -} diff --git a/internal/shell/company.go b/internal/shell/company.go index a3c1881..8158b0f 100644 --- a/internal/shell/company.go +++ b/internal/shell/company.go @@ -3,7 +3,6 @@ package shell import ( "Portifolio/internal/database" "Portifolio/internal/model" - "Portifolio/internal/service" "bufio" "database/sql" "fmt" @@ -16,9 +15,9 @@ import ( func AddCompany(scanner *bufio.Scanner, db *sql.DB) { input := model.CompanyInput{} - fmt.Print(" Name: ") + fmt.Print(" symbol: ") scanner.Scan() - input.Name = strings.TrimSpace(scanner.Text()) + input.Symbol = strings.TrimSpace(scanner.Text()) fmt.Print(" Shares outstanding: ") scanner.Scan() @@ -38,24 +37,19 @@ func AddCompany(scanner *bufio.Scanner, db *sql.DB) { } input.Price = price - fmt.Print(" Currency ID: ") + fmt.Print(" Currency Code: ") scanner.Scan() - cid, err := strconv.Atoi(strings.TrimSpace(scanner.Text())) - if err != nil { - fmt.Println(" Invalid currency ID.") - return - } - input.CurrencyID = cid + input.CurrencyCode = strings.TrimSpace(scanner.Text()) - if err := service.AddCompany(input, db); err != nil { + if _, err := database.AddCompany(db, input); err != nil { fmt.Println(" Error:", err) return } - fmt.Printf(" ✓ Company '%s' added.\n", input.Name) + fmt.Printf(" ✓ Company '%s' added.\n", input.Symbol) } func ListCompanies(db *sql.DB) { - companies, err := service.GetAllCompanies(db) + companies, err := database.GetAllCompanies(db) if err != nil { fmt.Println(" ✗ Error:", err) return @@ -77,6 +71,6 @@ func ListCompanies(db *sql.DB) { } fmt.Printf(" %-5d %-20s %-10s %-15.2f %d\n", - c.ID, c.Name, currency, c.Price, c.SharesOutstanding) + c.ID, c.Symbol, currency.Code, c.Price, c.SharesOutstanding) } } diff --git a/main.go b/main.go index f19970f..a76d0fb 100644 --- a/main.go +++ b/main.go @@ -30,8 +30,6 @@ func main() { } database.InitDB(db) - database.MigrateAddUniqueToRevenueEntries(db) - database.MigrateTradeCode(db) fmt.Println("Connected to SQLite database") http.HandleFunc("/health", handlers.HealthHandler(db)) @@ -47,7 +45,8 @@ func main() { http.HandleFunc("GET /company/revenue/categories", handlers.GetCompanyRevenueCategories(db)) // Currency - http.HandleFunc("GET /currencies", handlers.GetCurrenciesHandler(db)) + http.HandleFunc("GET /currency/list", handlers.GetCurrenciesHandler(db)) + http.HandleFunc("POST /currency/add", handlers.AddCurrencyHandler(db)) // Revenue http.HandleFunc("POST /add/revenue/entry", handlers.AddRevenueEntryHandler(db))