changes to model

This commit is contained in:
samantha42
2026-03-26 08:39:42 +01:00
parent ff7b41e2a8
commit 3f878c1dc0
15 changed files with 119 additions and 195 deletions

View File

@@ -11,9 +11,9 @@ import (
func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) {
var c model.Company
err := db.QueryRow(
`SELECT id, name, shares_outstanding, price, currency_id FROM companies WHERE id = ?`,
`SELECT id, symbol, shares_outstanding, price, currency_id FROM companies WHERE id = ?`,
id,
).Scan(&c.ID, &c.Name, &c.SharesOutstanding, &c.Price, &c.CurrencyID)
).Scan(&c.ID, &c.Symbol, &c.SharesOutstanding, &c.Price, &c.CurrencyID)
if err == sql.ErrNoRows {
return nil, nil
@@ -23,3 +23,49 @@ func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) {
}
return &c, nil
}
func AddCompany(db *sql.DB, input model.CompanyInput) (int, error) {
if input.CurrencyID == 0 {
if input.CurrencyCode != "" {
currency, err := GetCurrencyByCode(db, input.CurrencyCode)
if err != nil {
return 0, fmt.Errorf("could not get currency: %s", err)
}
input.CurrencyID = currency.ID
} else {
return 0, fmt.Errorf("no currency reference")
}
}
res, err := db.Exec(
`INSERT INTO companies (symbol, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`,
input.Symbol, input.SharesOutstanding, input.Price, input.CurrencyID,
)
if err != nil {
return 0, fmt.Errorf("failed to insert: %s", err)
}
id, err := res.LastInsertId()
return int(id), err
}
func GetAllCompanies(db *sql.DB) ([]model.Company, error) {
rows, err := db.Query(`
SELECT id, symbol, shares_outstanding, price, currency_id FROM companies
`)
if err != nil {
return nil, err
}
defer rows.Close()
var companies []model.Company
for rows.Next() {
var c model.Company
if err := rows.Scan(
&c.ID, &c.Symbol, &c.SharesOutstanding, &c.Price, &c.CurrencyID,
); err != nil {
return nil, err
}
companies = append(companies, c)
}
return companies, rows.Err()
}

View File

@@ -20,14 +20,22 @@ func GetCurrencyByID(db *sql.DB, ID int) (model.Currency, error) {
return c, nil
}
/*
CREATE TABLE IF NOT EXISTS currencies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT NOT NULL UNIQUE,
name TEXT NOT NULL
);
*/
func GetCurrencyByCode(db *sql.DB, Code string) (model.Currency, error) {
var c model.Currency
err := db.QueryRow(
`SELECT id, code, name, FROM currencies WHERE code = ?`,
`SELECT id, code, name FROM currencies WHERE code = ?`,
Code,
).Scan(&c.ID, &c.Code, &c.Name)
if err == sql.ErrNoRows {
return c, fmt.Errorf("company %d not found", Code)
return c, fmt.Errorf("company %s not found", Code)
}
return c, nil
}

View File

@@ -18,7 +18,7 @@ func InitDB(db *sql.DB) {
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL,
symbol TEXT NOT NULL,
currency_code TEXT NOT NULL,
shares INTEGER NOT NULL,
product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)),
@@ -29,16 +29,20 @@ func InitDB(db *sql.DB) {
CREATE TABLE IF NOT EXISTS position (
id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL,
company_id INTEGER NOT NULL UNIQUE,
symbol TEXT NOT NULL,
currency_id INTEGER NOT NULL,
currency_code TEXT NOT NULL,
shares INTEGER NOT NULL,
weight REAL NOT NULL,
cost_basis REAL NOT NULL
cost_basis REAL NOT NULL,
FOREIGN KEY (currency_id) REFERENCES currencies(id),
FOREIGN KEY (company_id) REFERENCES companies(id)
);
CREATE TABLE IF NOT EXISTS companies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
symbol TEXT NOT NULL UNIQUE,
shares_outstanding INTEGER NOT NULL,
price REAL NOT NULL,
currency_id INTEGER NOT NULL,
@@ -84,68 +88,3 @@ func InitDB(db *sql.DB) {
}
fmt.Println("Tables ready")
}
func MigrateAddUniqueToRevenueEntries(db *sql.DB) error {
steps := []string{
// 1. copy existing data into a temp table with the new constraint
`CREATE TABLE IF NOT EXISTS revenue_entries_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL,
currency_id INTEGER NOT NULL,
category_id INTEGER NOT NULL,
period_id INTEGER NOT NULL,
value REAL NOT NULL,
FOREIGN KEY (company_id) REFERENCES companies(id),
FOREIGN KEY (currency_id) REFERENCES currencies(id),
FOREIGN KEY (category_id) REFERENCES category(id),
FOREIGN KEY (period_id) REFERENCES periods(id),
UNIQUE(company_id, category_id, period_id)
)`,
// 2. copy data over
`INSERT OR IGNORE INTO revenue_entries_new
SELECT id, company_id, currency_id, category_id, period_id, value
FROM revenue_entries`,
// 3. drop old table
`DROP TABLE revenue_entries`,
// 4. rename new table
`ALTER TABLE revenue_entries_new RENAME TO revenue_entries`,
}
for _, step := range steps {
if _, err := db.Exec(step); err != nil {
return fmt.Errorf("migration failed: %w", err)
}
}
return nil
}
func MigrateTradeCode(db *sql.DB) error {
step :=
// 1. copy existing data into a temp table with the new constraint
`
ALTER TABLE trades RENAME TO trades_old;
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
company_id INTEGER NOT NULL,
currency_code TEXT NOT NULL,
shares INTEGER NOT NULL,
product INTEGER NOT NULL CHECK(product IN (0, 1, 2, 3)),
type INTEGER NOT NULL CHECK(type IN (0, 1)),
price REAL NOT NULL,
traded_at DATETIME NOT NULL
);
INSERT INTO trades (id, company_id, currency_code, shares, product, type, price, traded_at)
SELECT id, company_id, '', shares, product, type, price, traded_at
FROM trades_old;
DROP TABLE trades_old;
`
_, err := db.Exec(step)
if err != nil {
return fmt.Errorf("migration failed: %w", err)
}
return nil
}

View File

@@ -9,7 +9,7 @@ import (
)
func GetTrades(db *sql.DB) ([]model.Trade, error) {
rows, err := db.Query("SELECT company_id, currency_code, shares, product, type, price, traded_at FROM trades")
rows, err := db.Query("SELECT company_id, symbol, currency_id, currency_code, shares, product, type, price, traded_at FROM trades")
if err != nil {
return nil, err
}
@@ -17,21 +17,14 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) {
var trades []model.Trade
for rows.Next() {
var tickerInt int
var typeInt int
var t model.Trade
err := rows.Scan(&tickerInt, &t.Currency, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date)
err := rows.Scan(&t.CompanyID, &t.Symbol, &t.CompanyID, &t.CurrencyCode, &t.Shares, &t.Product, &typeInt, &t.Price, &t.Date)
if err != nil {
return nil, err
}
company, err := GetCompanyByID(db, tickerInt)
if err != nil {
return nil, err
}
t.Ticker = *company
switch typeInt {
case 0:
t.Type = model.TradeType(false)
@@ -50,7 +43,7 @@ func GetTrades(db *sql.DB) ([]model.Trade, error) {
}
func GetPositions(db *sql.DB) ([]model.Position, error) {
rows, err := db.Query("SELECT company_id, shares, weight, CostBases, currency_id")
rows, err := db.Query("SELECT company_id, symbol, shares, weight, CostBasis, currency_id, currency_code from position")
if err != nil {
return nil, err
}
@@ -59,10 +52,12 @@ func GetPositions(db *sql.DB) ([]model.Position, error) {
var positions []model.Position
for rows.Next() {
var t model.Position
err := rows.Scan(&t.Company.ID, &t.Shares, &t.Weight, &t.CostBasis, t.Currency)
err := rows.Scan(&t.CompanyID, &t.Symbol, &t.Shares, &t.Weight, &t.CostBasis, &t.CurrencyID, &t.CurrencyCode)
if err != nil {
return nil, err
}
positions = append(positions, t)
}
if err = rows.Err(); err != nil {
@@ -74,9 +69,11 @@ func GetPositions(db *sql.DB) ([]model.Position, error) {
func InsertTrade(db *sql.DB, trade model.Trade) error {
_, err := db.Exec(
"INSERT INTO trades (company_id, currency_id, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
trade.Ticker.ID,
trade.Currency,
"INSERT INTO trades (company_id, symbol, currency_id, currency_code, shares, product, type, price, traded_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
trade.CompanyID,
trade.Symbol,
trade.CurrencyID,
trade.CurrencyCode,
trade.Shares,
trade.Product,
trade.Type,

View File

@@ -1,8 +1,8 @@
package handlers
import (
"Portifolio/internal/database"
"Portifolio/internal/model"
"Portifolio/internal/service"
"database/sql"
"encoding/json"
"net/http"
@@ -79,11 +79,11 @@ func AddCompanyHandler(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var input model.CompanyInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}
id, err := service.InsertCompany(db, input)
id, err := database.AddCompany(db, input)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -91,13 +91,13 @@ func AddCompanyHandler(db *sql.DB) http.HandlerFunc {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{"status": "created", "id": id})
json.NewEncoder(w).Encode(map[string]int{"id": id})
}
}
func GetCompaniesHandler(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
companies, err := service.GetAllCompanies(db)
companies, err := database.GetAllCompanies(db)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return

View File

@@ -40,13 +40,13 @@ func AddTradeHandler(db *sql.DB) http.HandlerFunc {
}
trade := model.Trade{
Ticker: *company,
Shares: req.Shares,
Product: model.TradeProduct(req.Product),
Type: model.TradeType(req.Type),
Price: req.Price,
Currency: currency.Name,
Date: req.Date,
Symbol: company.Symbol,
Shares: req.Shares,
Product: model.TradeProduct(req.Product),
Type: model.TradeType(req.Type),
Price: req.Price,
CurrencyCode: currency.Code,
Date: req.Date,
}
database.InsertTrade(db, trade)

View File

@@ -2,15 +2,16 @@ package model
type Company struct {
ID int
Name string
Symbol string
SharesOutstanding int
Price float64
CurrencyID int
}
type CompanyInput struct {
Name string `json:"name"`
Symbol string `json:"symbol"`
SharesOutstanding int `json:"shares_outstanding"`
Price float64 `json:"price"`
CurrencyID int `json:"currency_id"`
CurrencyCode string `json:"currency_code"`
}

View File

@@ -6,11 +6,13 @@ import (
)
type Position struct {
Company Company
Currency Currency
Weight float64
CostBasis float64
Shares int
CompanyID int
Symbol string
CurrencyCode string
CurrencyID int
Weight float64
CostBasis float64
Shares int
}
type TradeProduct int
@@ -31,13 +33,15 @@ const (
)
type Trade struct {
Ticker Company
Shares int
Product TradeProduct
Type TradeType
Price float64
Currency string
Date time.Time
CompanyID int
Symbol string
CurrencyID int
CurrencyCode string
Shares int
Product TradeProduct
Type TradeType
Price float64
Date time.Time
}
type AddTradeRequest struct {

View File

@@ -1,49 +0,0 @@
package service
import (
"Portifolio/internal/model"
"database/sql"
_ "github.com/mattn/go-sqlite3"
)
func InsertCompany(db *sql.DB, input model.CompanyInput) (int, error) {
res, err := db.Exec(
`INSERT INTO companies (name, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`,
input.Name, input.SharesOutstanding, input.Price, input.CurrencyID,
)
if err != nil {
return 0, err
}
id, err := res.LastInsertId()
return int(id), err
}
func GetAllCompanies(db *sql.DB) ([]model.Company, error) {
rows, err := db.Query(`
SELECT c.id, c.name, c.shares_outstanding, c.price,
cu.id, cu.code, cu.name
FROM companies c
JOIN currencies cu ON c.currency_id = cu.id
ORDER BY c.name
`)
if err != nil {
return nil, err
}
defer rows.Close()
var companies []model.Company
for rows.Next() {
var c model.Company
var cu model.Currency
if err := rows.Scan(
&c.ID, &c.Name, &c.SharesOutstanding, &c.Price,
&cu.ID, &cu.Code, &cu.Name,
); err != nil {
return nil, err
}
c.CurrencyID = cu.ID
companies = append(companies, c)
}
return companies, rows.Err()
}

View File

@@ -1,16 +0,0 @@
package service
import (
"Portifolio/internal/model"
"database/sql"
_ "github.com/mattn/go-sqlite3"
)
func AddCompany(input model.CompanyInput, db *sql.DB) error {
_, err := db.Exec(
`INSERT INTO companies (name, shares_outstanding, price, currency_id) VALUES (?, ?, ?, ?)`,
input.Name, input.SharesOutstanding, input.Price, input.CurrencyID,
)
return err
}

View File

@@ -3,7 +3,6 @@ package shell
import (
"Portifolio/internal/database"
"Portifolio/internal/model"
"Portifolio/internal/service"
"bufio"
"database/sql"
"fmt"
@@ -16,9 +15,9 @@ import (
func AddCompany(scanner *bufio.Scanner, db *sql.DB) {
input := model.CompanyInput{}
fmt.Print(" Name: ")
fmt.Print(" symbol: ")
scanner.Scan()
input.Name = strings.TrimSpace(scanner.Text())
input.Symbol = strings.TrimSpace(scanner.Text())
fmt.Print(" Shares outstanding: ")
scanner.Scan()
@@ -38,24 +37,19 @@ func AddCompany(scanner *bufio.Scanner, db *sql.DB) {
}
input.Price = price
fmt.Print(" Currency ID: ")
fmt.Print(" Currency Code: ")
scanner.Scan()
cid, err := strconv.Atoi(strings.TrimSpace(scanner.Text()))
if err != nil {
fmt.Println(" Invalid currency ID.")
return
}
input.CurrencyID = cid
input.CurrencyCode = strings.TrimSpace(scanner.Text())
if err := service.AddCompany(input, db); err != nil {
if _, err := database.AddCompany(db, input); err != nil {
fmt.Println(" Error:", err)
return
}
fmt.Printf(" ✓ Company '%s' added.\n", input.Name)
fmt.Printf(" ✓ Company '%s' added.\n", input.Symbol)
}
func ListCompanies(db *sql.DB) {
companies, err := service.GetAllCompanies(db)
companies, err := database.GetAllCompanies(db)
if err != nil {
fmt.Println(" ✗ Error:", err)
return
@@ -77,6 +71,6 @@ func ListCompanies(db *sql.DB) {
}
fmt.Printf(" %-5d %-20s %-10s %-15.2f %d\n",
c.ID, c.Name, currency, c.Price, c.SharesOutstanding)
c.ID, c.Symbol, currency.Code, c.Price, c.SharesOutstanding)
}
}