diff --git a/Portifolio b/Portifolio index 27919fc..bccb166 100755 Binary files a/Portifolio and b/Portifolio differ diff --git a/app.db b/app.db index 23f7672..4d8c3da 100644 Binary files a/app.db and b/app.db differ diff --git a/internal/database/category.go b/internal/database/category.go new file mode 100644 index 0000000..323ecc4 --- /dev/null +++ b/internal/database/category.go @@ -0,0 +1,54 @@ +package database + +import ( + "Portifolio/internal/model" + "database/sql" + "fmt" +) + +func InsertCategory(db *sql.DB, rc model.RevenueCategory) error { + if err := rc.Validate(false); err != nil { + return fmt.Errorf("failed to insert: %w", err) + } + + company, err := GetCompanyByID(db, rc.CompanyID) + if err != nil { + return fmt.Errorf("failed to check company: %w", err) + } + if company == nil { + return fmt.Errorf("company %d not found", rc.CompanyID) + } + + if rc.ParentID != nil && *rc.ParentID != 0 { + parent, err := GetCategoryByID(db, rc.CompanyID, *rc.ParentID) + if err != nil { + return fmt.Errorf("failed to check parent category: %w", err) + } + if parent == nil { + return fmt.Errorf("parent category %d not found", *rc.ParentID) + } + } + + var parentID sql.NullInt64 + if rc.ParentID != nil && *rc.ParentID != 0 { + parentID = sql.NullInt64{Int64: int64(*rc.ParentID), Valid: true} + } + + _, err = db.Exec( + `INSERT INTO category (company_id, parent_id, name) VALUES (?, ?, ?) + ON CONFLICT(company_id, name) DO UPDATE SET parent_id=excluded.parent_id`, + rc.CompanyID, parentID, rc.Name, + ) + if err != nil { + return fmt.Errorf("upsert category: %w", err) + } + + err = db.QueryRow( + `SELECT id FROM category WHERE company_id = ? AND name = ?`, + rc.CompanyID, rc.Name, + ).Scan(&rc.ID) + if err != nil { + return fmt.Errorf("select category id: %w", err) + } + return nil +} diff --git a/internal/database/company.go b/internal/database/company.go new file mode 100644 index 0000000..6be510c --- /dev/null +++ b/internal/database/company.go @@ -0,0 +1,25 @@ +package database + +import ( + "Portifolio/internal/model" + "database/sql" + "fmt" + + _ "github.com/mattn/go-sqlite3" +) + +func GetCompanyByID(db *sql.DB, id int) (*model.Company, error) { + var c model.Company + err := db.QueryRow( + `SELECT id, name, shares_outstanding, price, currency_id FROM companies WHERE id = ?`, + id, + ).Scan(&c.ID, &c.Name, &c.SharesOutstanding, &c.Price, &c.CurrencyID) + + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("query company: %w", err) + } + return &c, nil +} diff --git a/internal/database/currency.go b/internal/database/currency.go new file mode 100644 index 0000000..e5d9fe0 --- /dev/null +++ b/internal/database/currency.go @@ -0,0 +1,21 @@ +package database + +import ( + "Portifolio/internal/model" + "database/sql" + "fmt" + + _ "github.com/mattn/go-sqlite3" +) + +func GetCurrencyByID(db *sql.DB, ID int) (model.Currency, error) { + var c model.Currency + err := db.QueryRow( + `SELECT id, code, name, FROM currencies WHERE id = ?`, + ID, + ).Scan(&c.ID, &c.Code, &c.Name) + if err == sql.ErrNoRows { + return c, fmt.Errorf("company %d not found", ID) + } + return c, nil +} diff --git a/internal/database/revenue.go b/internal/database/revenue.go index fd59555..f76da0d 100644 --- a/internal/database/revenue.go +++ b/internal/database/revenue.go @@ -9,7 +9,23 @@ import ( _ "github.com/mattn/go-sqlite3" ) +func GetCategoryByID(db *sql.DB, companyID int, ID int) (*model.RevenueCategory, error) { + var rc model.RevenueCategory + err := db.QueryRow( + `SELECT id, company_id, parent_id, name FROM category WHERE company_id = ? AND id = ?`, + companyID, ID, + ).Scan(&rc.ID, &rc.CompanyID, &rc.ParentID, &rc.Name) + if err == sql.ErrNoRows { + return &rc, fmt.Errorf("category %q not found for company %d", ID, companyID) + } + if err != nil { + return &rc, fmt.Errorf("get category by name: %w", err) + } + return &rc, nil +} + func GetCategoryByName(db *sql.DB, companyID int, name string) (model.RevenueCategory, error) { + var rc model.RevenueCategory err := db.QueryRow( `SELECT id, company_id, parent_id, name FROM category WHERE company_id = ? AND name = ?`, diff --git a/internal/model/company.go b/internal/model/company.go index 74ee95c..43d6bb4 100644 --- a/internal/model/company.go +++ b/internal/model/company.go @@ -6,7 +6,6 @@ type Company struct { SharesOutstanding int Price float64 CurrencyID int - Currency *Currency // populated on joins } type CompanyInput struct { diff --git a/internal/model/periode.go b/internal/model/periode.go index c909907..6606900 100644 --- a/internal/model/periode.go +++ b/internal/model/periode.go @@ -2,10 +2,65 @@ package model import ( + "database/sql" "fmt" "time" ) +// PeriodType defines the granularity of the revenue entry +type PeriodType string + +const ( + PeriodQuarter PeriodType = "Q" + PeriodHalfYear PeriodType = "H" + PeriodYear PeriodType = "Y" +) + +// Period holds the actual time range for a revenue entry +type Period struct { + ID int + Type PeriodType + Year int + Index int // Q1=1 Q2=2 Q3=3 Q4=4 | H1=1 H2=2 | FY=1 + Start time.Time + End time.Time +} + +func (p *Period) Validate(checkID bool) error { + if p.ID == 0 && checkID { + return fmt.Errorf("No ID Set") + } else if p.Type == PeriodQuarter && (p.Index > 4 || p.Index < 1) { + return fmt.Errorf("Not Valid Quarter index") + } else if p.Type == PeriodHalfYear && (p.Index > 2 || p.Index < 1) { + return fmt.Errorf("Not Valid HalfYear index") + } else if p.Type == PeriodHalfYear && (p.Index != 1 || p.Year < 1) { + return fmt.Errorf("Not Valid Year index") + } + return nil +} + +func (p *Period) Insert(db *sql.DB) error { + _, err := db.Exec( + `INSERT INTO periods (type, year, idx, start_date, end_date) VALUES (?, ?, ?, ?, ?) + ON CONFLICT(type, year, idx) DO UPDATE SET start_date=excluded.start_date, end_date=excluded.end_date`, + string(p.Type), p.Year, p.Index, p.Start.Format("2006-01-02"), p.End.Format("2006-01-02"), + ) + if err != nil { + return fmt.Errorf("upsert period: %w", err) + } + + var id int + err = db.QueryRow( + `SELECT id FROM periods WHERE type = ? AND year = ? AND idx = ?`, + string(p.Type), p.Year, p.Index, + ).Scan(&id) + if err != nil { + return fmt.Errorf("select period: %w", err) + } + p.ID = id + return nil +} + func QuarterPeriod(year, q int) Period { months := map[int][2]int{ 1: {1, 3}, 2: {4, 6}, 3: {7, 9}, 4: {10, 12}, diff --git a/internal/model/revenue.go b/internal/model/revenue.go index 70cc388..220288c 100644 --- a/internal/model/revenue.go +++ b/internal/model/revenue.go @@ -1,52 +1,9 @@ package model import ( - "database/sql" "fmt" - "time" ) -// PeriodType defines the granularity of the revenue entry -type PeriodType string - -const ( - PeriodQuarter PeriodType = "Q" - PeriodHalfYear PeriodType = "H" - PeriodYear PeriodType = "Y" -) - -// Period holds the actual time range for a revenue entry -type Period struct { - ID int - Type PeriodType - Year int - Index int // Q1=1 Q2=2 Q3=3 Q4=4 | H1=1 H2=2 | FY=1 - Start time.Time - End time.Time -} - -func (p *Period) Insert(db *sql.DB) error { - _, err := db.Exec( - `INSERT INTO periods (type, year, idx, start_date, end_date) VALUES (?, ?, ?, ?, ?) - ON CONFLICT(type, year, idx) DO UPDATE SET start_date=excluded.start_date, end_date=excluded.end_date`, - string(p.Type), p.Year, p.Index, p.Start.Format("2006-01-02"), p.End.Format("2006-01-02"), - ) - if err != nil { - return fmt.Errorf("upsert period: %w", err) - } - - var id int - err = db.QueryRow( - `SELECT id FROM periods WHERE type = ? AND year = ? AND idx = ?`, - string(p.Type), p.Year, p.Index, - ).Scan(&id) - if err != nil { - return fmt.Errorf("select period: %w", err) - } - p.ID = id - return nil -} - type RevenueCategory struct { ID int CompanyID int @@ -54,8 +11,8 @@ type RevenueCategory struct { Name string // e.g. "product", "location", "segment" } -func (rc *RevenueCategory) Validate() error { - if rc.ID == 0 { +func (rc *RevenueCategory) Validate(checkID bool) error { + if rc.ID == 0 && checkID { return fmt.Errorf("No ID Set") } else if rc.Name == "" { return fmt.Errorf("No Name found") @@ -65,31 +22,6 @@ func (rc *RevenueCategory) Validate() error { return nil } -func (rc *RevenueCategory) Insert(db *sql.DB) error { - if err := rc.Validate(); err != nil { - return fmt.Errorf("failed to insert: %w", err) - } - - _, err := db.Exec( - `INSERT INTO category (company_id, parent_id, name) VALUES (?, ?, ?) - ON CONFLICT(company_id, name) DO UPDATE SET parent_id=excluded.parent_id`, - rc.CompanyID, rc.ParentID, rc.Name, - ) - if err != nil { - return fmt.Errorf("upsert category: %w", err) - } - - err = db.QueryRow( - `SELECT id FROM category WHERE company_id = ? AND name = ?`, - rc.CompanyID, rc.Name, - ).Scan(&rc.ID) - if err != nil { - return fmt.Errorf("select category id: %w", err) - } - - return nil -} - // Revenue is a single line in a financial report type Revenue struct { ID int diff --git a/internal/service/company.go b/internal/service/company.go index e3c50a1..ed3ef16 100644 --- a/internal/service/company.go +++ b/internal/service/company.go @@ -43,7 +43,6 @@ func GetAllCompanies(db *sql.DB) ([]model.Company, error) { return nil, err } c.CurrencyID = cu.ID - c.Currency = &cu companies = append(companies, c) } return companies, rows.Err() diff --git a/internal/service/revenue.go b/internal/service/revenue.go index a103d87..b995b91 100644 --- a/internal/service/revenue.go +++ b/internal/service/revenue.go @@ -10,9 +10,19 @@ import ( _ "github.com/mattn/go-sqlite3" ) -func InsertRevenue(db *sql.DB, companyID, currencyID int, categoryName string, parentID *int, value float64, period model.Period) error { +func InsertRevenue(db *sql.DB, companyID int, currencyID int, categoryName string, parentID *int, value float64, period model.Period) error { + _, err := database.GetCompanyByID(db, companyID) + if err != nil { + return err + } - period, err := database.GetPeriodByID(db, period.ID) + _, err = database.GetCurrencyByID(db, currencyID) + if err != nil { + return err + } + + // checking if period is in db, in case not will insert + _, err = database.GetPeriodByID(db, period.ID) if err != nil { err = period.Insert(db) if err != nil { @@ -27,7 +37,7 @@ func InsertRevenue(db *sql.DB, companyID, currencyID int, categoryName string, p ParentID: parentID, Name: categoryName, } - err := category.Insert(db) + err := database.InsertCategory(db, category) if err != nil { return err } diff --git a/internal/shell/company.go b/internal/shell/company.go index 9a5d429..a3c1881 100644 --- a/internal/shell/company.go +++ b/internal/shell/company.go @@ -1,6 +1,7 @@ package shell import ( + "Portifolio/internal/database" "Portifolio/internal/model" "Portifolio/internal/service" "bufio" @@ -68,10 +69,13 @@ func ListCompanies(db *sql.DB) { fmt.Printf("\n %-5s %-20s %-10s %-15s %s\n", "ID", "NAME", "CURRENCY", "PRICE", "SHARES") fmt.Println(" " + strings.Repeat("-", 60)) for _, c := range companies { - currency := "N/A" - if c.Currency != nil { - currency = c.Currency.Code + + currency, err := database.GetCurrencyByID(db, c.CurrencyID) + if err != nil { + fmt.Println("No currency by id.") + return } + fmt.Printf(" %-5d %-20s %-10s %-15.2f %d\n", c.ID, c.Name, currency, c.Price, c.SharesOutstanding) } diff --git a/internal/shell/revenue.go b/internal/shell/revenue.go index b50649e..3a05ae9 100644 --- a/internal/shell/revenue.go +++ b/internal/shell/revenue.go @@ -1,6 +1,7 @@ package shell import ( + "Portifolio/internal/database" "Portifolio/internal/model" "Portifolio/internal/service" "bufio" @@ -67,6 +68,12 @@ func AddRevenue(scanner *bufio.Scanner, db *sql.DB) { fmt.Println(" ✗", err) return } + // checking if company exits + _, err = database.GetCompanyByID(db, companyID) + if err != nil { + fmt.Println("No company by that id:", err) + return + } currencyID, err := promptInt(scanner, " Currency ID: ") if err != nil { fmt.Println(" ✗", err)