changes to model

This commit is contained in:
samantha42
2026-03-26 08:39:42 +01:00
parent ff7b41e2a8
commit 3f878c1dc0
15 changed files with 119 additions and 195 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
data/**

Binary file not shown.

BIN
app.db

Binary file not shown.

View File

@@ -11,9 +11,9 @@ import (
func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) { func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) {
var c model.Company var c model.Company
err := db.QueryRow( err := db.QueryRow(
`SELECT id, name, shares_outstanding, price, currency_id FROM companies WHERE id = ?`, `SELECT id, symbol, shares_outstanding, price, currency_id FROM companies WHERE id = ?`,
id, id,
).Scan(&c.ID, &c.Name, &c.SharesOutstanding, &c.Price, &c.CurrencyID) ).Scan(&c.ID, &c.Symbol, &c.SharesOutstanding, &c.Price, &c.CurrencyID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
@@ -23,3 +23,49 @@ func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) {
} }
return &c, nil return &c, nil
} }
func AddCompany(db *sql.DB, input model.CompanyInput) (int, error) {
if input.CurrencyID == 0 {
if input.CurrencyCode != "" {
currency, err := GetCurrencyByCode(db, input.CurrencyCode)
if err != nil {
return 0, fmt.Errorf("could not get currency: %s", err)
}
input.CurrencyID = currency.ID
} else {
return 0, fmt.Errorf("no currency reference")
}
}
res, err := db.Exec(
`INSERT INTO companies (symbol, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`,
input.Symbol, input.SharesOutstanding, input.Price, input.CurrencyID,
)
if err != nil {
return 0, fmt.Errorf("failed to insert: %s", err)
}
id, err := res.LastInsertId()
return int(id), err
}
func GetAllCompanies(db *sql.DB) ([]model.Company, error) {
rows, err := db.Query(`
SELECT id, symbol, shares_outstanding, price, currency_id FROM companies
`)
if err != nil {
return nil, err
}
defer rows.Close()
var companies []model.Company
for rows.Next() {
var c model.Company
if err := rows.Scan(
&c.ID, &c.Symbol, &c.SharesOutstanding, &c.Price, &c.CurrencyID,
); err != nil {
return nil, err
}
companies = append(companies, c)
}
return companies, rows.Err()
}

View File

@@ -20,14 +20,22 @@ func GetCurrencyByID(db *sql.DB, ID int) (model.Currency, error) {
return c, nil return c, nil
} }
/*
CREATE TABLE IF NOT EXISTS currencies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL UNIQUE,
name TEXT NOT NULL
);
*/
func GetCurrencyByCode(db *sql.DB, Code string) (model.Currency, error) { func GetCurrencyByCode(db *sql.DB, Code string) (model.Currency, error) {
var c model.Currency var c model.Currency
err := db.QueryRow( err := db.QueryRow(
`SELECT id, code, name, FROM currencies WHERE code = ?`, `SELECT id, code, name FROM currencies WHERE code = ?`,
Code, Code,
).Scan(&c.ID, &c.Code, &c.Name) ).Scan(&c.ID, &c.Code, &c.Name)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return c, fmt.Errorf("company %d not found", Code) return c, fmt.Errorf("company %s not found", Code)
} }
return c, nil return c, nil
} }

View File

@@ -18,7 +18,7 @@ func InitDB(db *sql.DB) {
CREATE TABLE IF NOT EXISTS trades ( CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL, symbol TEXT NOT NULL,
currency_code TEXT NOT NULL, currency_code TEXT NOT NULL,
shares INTEGER NOT NULL, shares INTEGER NOT NULL,
product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)), product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)),
@@ -29,16 +29,20 @@ func InitDB(db *sql.DB) {
CREATE TABLE IF NOT EXISTS position ( CREATE TABLE IF NOT EXISTS position (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL, company_id INTEGER NOT NULL UNIQUE,
symbol TEXT NOT NULL,
currency_id INTEGER NOT NULL, currency_id INTEGER NOT NULL,
currency_code TEXT NOT NULL,
shares INTEGER NOT NULL, shares INTEGER NOT NULL,
weight REAL NOT NULL, weight REAL NOT NULL,
cost_basis REAL NOT NULL cost_basis REAL NOT NULL,
FOREIGN KEY (currency_id) REFERENCES currencies(id),
FOREIGN KEY (company_id) REFERENCES companies(id)
); );
CREATE TABLE IF NOT EXISTS companies ( CREATE TABLE IF NOT EXISTS companies (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE, symbol TEXT NOT NULL UNIQUE,
shares_outstanding INTEGER NOT NULL, shares_outstanding INTEGER NOT NULL,
price REAL NOT NULL, price REAL NOT NULL,
currency_id INTEGER NOT NULL, currency_id INTEGER NOT NULL,
@@ -84,68 +88,3 @@ func InitDB(db *sql.DB) {
} }
fmt.Println("Tables ready") fmt.Println("Tables ready")
} }
func MigrateAddUniqueToRevenueEntries(db *sql.DB) error {
steps := []string{
// 1. copy existing data into a temp table with the new constraint
`CREATE TABLE IF NOT EXISTS revenue_entries_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL,
currency_id INTEGER NOT NULL,
category_id INTEGER NOT NULL,
period_id INTEGER NOT NULL,
value REAL NOT NULL,
FOREIGN KEY (company_id) REFERENCES companies(id),
FOREIGN KEY (currency_id) REFERENCES currencies(id),
FOREIGN KEY (category_id) REFERENCES category(id),
FOREIGN KEY (period_id) REFERENCES periods(id),
UNIQUE(company_id, category_id, period_id)
)`,
// 2. copy data over
`INSERT OR IGNORE INTO revenue_entries_new
SELECT id, company_id, currency_id, category_id, period_id, value
FROM revenue_entries`,
// 3. drop old table
`DROP TABLE revenue_entries`,
// 4. rename new table
`ALTER TABLE revenue_entries_new RENAME TO revenue_entries`,
}
for _, step := range steps {
if _, err := db.Exec(step); err != nil {
return fmt.Errorf("migration failed: %w", err)
}
}
return nil
}
func MigrateTradeCode(db *sql.DB) error {
step :=
// 1. copy existing data into a temp table with the new constraint
`
ALTER TABLE trades RENAME TO trades_old;
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL,
currency_code TEXT NOT NULL,
shares INTEGER NOT NULL,
product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)),
type INTEGER NOT NULL CHECK(type IN (0, 1)),
price REAL NOT NULL,
traded_at DATETIME NOT NULL
);
INSERT INTO trades (id, company_id, currency_code, shares, product, type, price, traded_at)
SELECT id, company_id, '', shares, product, type, price, traded_at
FROM trades_old;
DROP TABLE trades_old;
`
_, err := db.Exec(step)
if err != nil {
return fmt.Errorf("migration failed: %w", err)
}
return nil
}

View File

@@ -9,7 +9,7 @@ import (
) )
func GetTrades(db *sql.DB) ([]model.Trade, error) { func GetTrades(db *sql.DB) ([]model.Trade, error) {
rows, err := db.Query("SELECT company_id, currency_code, shares, product, type, price, traded_at FROM trades") rows, err := db.Query("SELECT company_id, symbol, currency_id, currency_code, shares, product, type, price, traded_at FROM trades")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -17,21 +17,14 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) {
var trades []model.Trade var trades []model.Trade
for rows.Next() { for rows.Next() {
var tickerInt int
var typeInt int var typeInt int
var t model.Trade var t model.Trade
err := rows.Scan(&tickerInt, &t.Currency, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date) err := rows.Scan(&t.CompanyID, &t.Symbol, &t.CompanyID, &t.CurrencyCode, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date)
if err != nil { if err != nil {
return nil, err return nil, err
} }
company, err := GetCompanyByID(db, tickerInt)
if err != nil {
return nil, err
}
t.Ticker = *company
switch typeInt { switch typeInt {
case 0: case 0:
t.Type = model.TradeType(false) t.Type = model.TradeType(false)
@@ -50,7 +43,7 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) {
} }
func GetPositions(db *sql.DB) ([]model.Position, error) { func GetPositions(db *sql.DB) ([]model.Position, error) {
rows, err := db.Query("SELECT company_id, shares, weight, CostBases, currency_id") rows, err := db.Query("SELECT company_id, symbol, shares, weight, CostBasis, currency_id, currency_code from position")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -59,10 +52,12 @@ func GetPositions(db *sql.DB) ([]model.Position, error) {
var positions []model.Position var positions []model.Position
for rows.Next() { for rows.Next() {
var t model.Position var t model.Position
err := rows.Scan(&t.Company.ID, &t.Shares, &t.Weight, &t.CostBasis, t.Currency)
err := rows.Scan(&t.CompanyID, &t.Symbol, &t.Shares, &t.Weight, &t.CostBasis, &t.CurrencyID, &t.CurrencyCode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions = append(positions, t) positions = append(positions, t)
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
@@ -74,9 +69,11 @@ func GetPositions(db *sql.DB) ([]model.Position, error) {
func InsertTrade(db *sql.DB, trade model.Trade) error { func InsertTrade(db *sql.DB, trade model.Trade) error {
_, err := db.Exec( _, err := db.Exec(
"INSERT INTO trades (company_id, currency_id, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?)", "INSERT INTO trades (company_id, symbol, currency_id, currency_code, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
trade.Ticker.ID, trade.CompanyID,
trade.Currency, trade.Symbol,
trade.CurrencyID,
trade.CurrencyCode,
trade.Shares, trade.Shares,
trade.Product, trade.Product,
trade.Type, trade.Type,

View File

@@ -1,8 +1,8 @@
package handlers package handlers
import ( import (
"Portifolio/internal/database"
"Portifolio/internal/model" "Portifolio/internal/model"
"Portifolio/internal/service"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"net/http" "net/http"
@@ -79,11 +79,11 @@ func AddCompanyHandler(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var input model.CompanyInput var input model.CompanyInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil { if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest) http.Error(w, "invalid request body", http.StatusBadRequest)
return return
} }
id, err := service.InsertCompany(db, input) id, err := database.AddCompany(db, input)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@@ -91,13 +91,13 @@ func AddCompanyHandler(db *sql.DB) http.HandlerFunc {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{"status": "created", "id": id}) json.NewEncoder(w).Encode(map[string]int{"id": id})
} }
} }
func GetCompaniesHandler(db *sql.DB) http.HandlerFunc { func GetCompaniesHandler(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
companies, err := service.GetAllCompanies(db) companies, err := database.GetAllCompanies(db)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return

View File

@@ -40,13 +40,13 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc {
} }
trade := model.Trade{ trade := model.Trade{
Ticker: *company, Symbol: company.Symbol,
Shares: req.Shares, Shares: req.Shares,
Product: model.TradeProduct(req.Product), Product: model.TradeProduct(req.Product),
Type: model.TradeType(req.Type), Type: model.TradeType(req.Type),
Price: req.Price, Price: req.Price,
Currency: currency.Name, CurrencyCode: currency.Code,
Date: req.Date, Date: req.Date,
} }
database.InsertTrade(db, trade) database.InsertTrade(db, trade)

View File

@@ -2,15 +2,16 @@ package model
type Company struct { type Company struct {
ID int ID int
Name string Symbol string
SharesOutstanding int SharesOutstanding int
Price float64 Price float64
CurrencyID int CurrencyID int
} }
type CompanyInput struct { type CompanyInput struct {
Name string `json:"name"` Symbol string `json:"symbol"`
SharesOutstanding int `json:"shares_outstanding"` SharesOutstanding int `json:"shares_outstanding"`
Price float64 `json:"price"` Price float64 `json:"price"`
CurrencyID int `json:"currency_id"` CurrencyID int `json:"currency_id"`
CurrencyCode string `json:"currency_code"`
} }

View File

@@ -6,11 +6,13 @@ import (
) )
type Position struct { type Position struct {
Company Company CompanyID int
Currency Currency Symbol string
Weight float64 CurrencyCode string
CostBasis float64 CurrencyID int
Shares int Weight float64
CostBasis float64
Shares int
} }
type TradeProduct int type TradeProduct int
@@ -31,13 +33,15 @@ const (
) )
type Trade struct { type Trade struct {
Ticker Company CompanyID int
Shares int Symbol string
Product TradeProduct CurrencyID int
Type TradeType CurrencyCode string
Price float64 Shares int
Currency string Product TradeProduct
Date time.Time Type TradeType
Price float64
Date time.Time
} }
type AddTradeRequest struct { type AddTradeRequest struct {

View File

@@ -1,49 +0,0 @@
package service
import (
"Portifolio/internal/model"
"database/sql"
_ "github.com/mattn/go-sqlite3"
)
func InsertCompany(db *sql.DB, input model.CompanyInput) (int, error) {
res, err := db.Exec(
`INSERT INTO companies (name, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`,
input.Name, input.SharesOutstanding, input.Price, input.CurrencyID,
)
if err != nil {
return 0, err
}
id, err := res.LastInsertId()
return int(id), err
}
func GetAllCompanies(db *sql.DB) ([]model.Company, error) {
rows, err := db.Query(`
SELECT c.id, c.name, c.shares_outstanding, c.price,
cu.id, cu.code, cu.name
FROM companies c
JOIN currencies cu ON c.currency_id = cu.id
ORDER BY c.name
`)
if err != nil {
return nil, err
}
defer rows.Close()
var companies []model.Company
for rows.Next() {
var c model.Company
var cu model.Currency
if err := rows.Scan(
&c.ID, &c.Name, &c.SharesOutstanding, &c.Price,
&cu.ID, &cu.Code, &cu.Name,
); err != nil {
return nil, err
}
c.CurrencyID = cu.ID
companies = append(companies, c)
}
return companies, rows.Err()
}

View File

@@ -1,16 +0,0 @@
package service
import (
"Portifolio/internal/model"
"database/sql"
_ "github.com/mattn/go-sqlite3"
)
func AddCompany(input model.CompanyInput, db *sql.DB) error {
_, err := db.Exec(
`INSERT INTO companies (name, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`,
input.Name, input.SharesOutstanding, input.Price, input.CurrencyID,
)
return err
}

View File

@@ -3,7 +3,6 @@ package shell
import ( import (
"Portifolio/internal/database" "Portifolio/internal/database"
"Portifolio/internal/model" "Portifolio/internal/model"
"Portifolio/internal/service"
"bufio" "bufio"
"database/sql" "database/sql"
"fmt" "fmt"
@@ -16,9 +15,9 @@ import (
func AddCompany(scanner *bufio.Scanner, db *sql.DB) { func AddCompany(scanner *bufio.Scanner, db *sql.DB) {
input := model.CompanyInput{} input := model.CompanyInput{}
fmt.Print(" Name: ") fmt.Print(" symbol: ")
scanner.Scan() scanner.Scan()
input.Name = strings.TrimSpace(scanner.Text()) input.Symbol = strings.TrimSpace(scanner.Text())
fmt.Print(" Shares outstanding: ") fmt.Print(" Shares outstanding: ")
scanner.Scan() scanner.Scan()
@@ -38,24 +37,19 @@ func AddCompany(scanner *bufio.Scanner, db *sql.DB) {
} }
input.Price = price input.Price = price
fmt.Print(" Currency ID: ") fmt.Print(" Currency Code: ")
scanner.Scan() scanner.Scan()
cid, err := strconv.Atoi(strings.TrimSpace(scanner.Text())) input.CurrencyCode = strings.TrimSpace(scanner.Text())
if err != nil {
fmt.Println(" Invalid currency ID.")
return
}
input.CurrencyID = cid
if err := service.AddCompany(input, db); err != nil { if _, err := database.AddCompany(db, input); err != nil {
fmt.Println(" Error:", err) fmt.Println(" Error:", err)
return return
} }
fmt.Printf(" ✓ Company '%s' added.\n", input.Name) fmt.Printf(" ✓ Company '%s' added.\n", input.Symbol)
} }
func ListCompanies(db *sql.DB) { func ListCompanies(db *sql.DB) {
companies, err := service.GetAllCompanies(db) companies, err := database.GetAllCompanies(db)
if err != nil { if err != nil {
fmt.Println(" ✗ Error:", err) fmt.Println(" ✗ Error:", err)
return return
@@ -77,6 +71,6 @@ func ListCompanies(db *sql.DB) {
} }
fmt.Printf(" %-5d %-20s %-10s %-15.2f %d\n", fmt.Printf(" %-5d %-20s %-10s %-15.2f %d\n",
c.ID, c.Name, currency, c.Price, c.SharesOutstanding) c.ID, c.Symbol, currency.Code, c.Price, c.SharesOutstanding)
} }
} }

View File

@@ -30,8 +30,6 @@ func main() {
} }
database.InitDB(db) database.InitDB(db)
database.MigrateAddUniqueToRevenueEntries(db)
database.MigrateTradeCode(db)
fmt.Println("Connected to SQLite database") fmt.Println("Connected to SQLite database")
http.HandleFunc("/health", handlers.HealthHandler(db)) http.HandleFunc("/health", handlers.HealthHandler(db))
@@ -47,7 +45,8 @@ func main() {
http.HandleFunc("GET /company/revenue/categories", handlers.GetCompanyRevenueCategories(db)) http.HandleFunc("GET /company/revenue/categories", handlers.GetCompanyRevenueCategories(db))
// Currency // Currency
http.HandleFunc("GET /currencies", handlers.GetCurrenciesHandler(db)) http.HandleFunc("GET /currency/list", handlers.GetCurrenciesHandler(db))
http.HandleFunc("POST /currency/add", handlers.AddCurrencyHandler(db))
// Revenue // Revenue
http.HandleFunc("POST /add/revenue/entry", handlers.AddRevenueEntryHandler(db)) http.HandleFunc("POST /add/revenue/entry", handlers.AddRevenueEntryHandler(db))