From cf4b8e01194ffa70be9b43e89daa2cb16a414edb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20Forc=C3=A9n=20Mu=C3=B1oz?= Date: Wed, 16 Apr 2025 23:24:23 +0200 Subject: [PATCH] Completed db rewrite --- apis/base.go | 6 +- apis/expenses.go | 23 +++---- apis/extract.go | 14 ++-- apis/plans.go | 18 +++-- apis/polls.go | 27 ++++++-- apis/users.go | 27 +++++--- base.sql | 47 +++++++++++++ core/debts.go | 12 ++++ core/expenses.go | 159 ++++++++++++++++++++++++++++++-------------- core/models.go | 56 ++++++---------- core/plans.go | 118 ++++++++++++++++---------------- core/polls.go | 64 +++++++++++------- core/sqlexecutor.go | 9 +++ core/users.go | 93 +++++++++++++++++++------- core/votes.go | 23 +++++++ db.go | 44 +++++------- go.mod | 8 +-- go.sum | 2 + planner.go | 7 +- 19 files changed, 481 insertions(+), 276 deletions(-) create mode 100644 base.sql create mode 100644 core/debts.go create mode 100644 core/sqlexecutor.go create mode 100644 core/votes.go diff --git a/apis/base.go b/apis/base.go index a60e9e6..a577c85 100644 --- a/apis/base.go +++ b/apis/base.go @@ -1,16 +1,16 @@ package apis import ( + "database/sql" "errors" "net/http" "github.com/gin-gonic/gin" - "gorm.io/gorm" ) -var db *gorm.DB +var db *sql.DB -func BindAPIs(r *gin.Engine, cfg_db *gorm.DB) error { +func BindAPIs(r *gin.Engine, cfg_db *sql.DB) error { if cfg_db == nil { return errors.New("Database is null") } diff --git a/apis/expenses.go b/apis/expenses.go index c38429f..68a7685 100644 --- a/apis/expenses.go +++ b/apis/expenses.go @@ -49,7 +49,7 @@ func createExpense(c *gin.Context) { if exp.Type == "user" { compare = func(m core.Member) bool { - return m.UserID == exp.Payer + return *m.UserId == exp.Payer } } else if exp.Type == "non-user" { compare = func(m core.Member) bool { @@ -61,7 +61,7 @@ func createExpense(c *gin.Context) { } idx := slices.IndexFunc(members, compare) if idx != -1 { - expense, err := core.CreateExpense(db, plan, members[idx], exp.Amount) + expense, err := core.ExpenseCreate(db, *plan, members[idx], exp.Amount) if err != nil { c.Status(http.StatusInternalServerError) return @@ -89,7 +89,7 @@ func listExpenses(c *gin.Context) { c.Status(http.StatusBadRequest) return } - expenses, err := core.ListExpenses(db, plan) + expenses, err := core.ExpensesList(db, *plan) if err != nil { c.Status(http.StatusInternalServerError) return @@ -109,12 +109,12 @@ func getExpense(c *gin.Context) { c.Status(http.StatusBadRequest) return } - expense, err := core.GetExpense(db, uint(expense_id)) + expense, err := core.ExpensesGet(db, uint(expense_id)) if err != nil { c.Status(http.StatusBadRequest) return } - _, err = u.GetPlan(db, expense.PlanID) + _, err = u.GetPlan(db, expense.PlanId) if err != nil { c.Status(http.StatusBadRequest) return @@ -135,12 +135,12 @@ func deleteExpense(c *gin.Context) { c.Status(http.StatusBadRequest) return } - expense, err := core.GetExpense(db, uint(expense_id)) + expense, err := core.ExpensesGet(db, uint(expense_id)) if err != nil { c.Status(http.StatusBadRequest) return } - _, err = u.GetPlan(db, expense.PlanID) + _, err = u.GetPlan(db, expense.PlanId) if err != nil { c.Status(http.StatusBadRequest) return @@ -166,12 +166,12 @@ func getExpenseDebts(c *gin.Context) { c.Status(http.StatusBadRequest) return } - expense, err := core.GetExpense(db, uint(expense_id)) + expense, err := core.ExpensesGet(db, uint(expense_id)) if err != nil { c.Status(http.StatusBadRequest) return } - _, err = u.GetPlan(db, expense.PlanID) + _, err = u.GetPlan(db, expense.PlanId) if err != nil { c.Status(http.StatusBadRequest) return @@ -198,17 +198,16 @@ func setExpenseDebts(c *gin.Context) { c.Status(http.StatusBadRequest) return } - expense, err := core.GetExpense(db, uint(expense_id)) + expense, err := core.ExpensesGet(db, uint(expense_id)) if err != nil { c.Status(http.StatusBadRequest) return } - plan, err := u.GetPlan(db, expense.PlanID) + _, err = u.GetPlan(db, expense.PlanId) if err != nil { c.Status(http.StatusBadRequest) return } - expense.Plan = plan var debtSpec []core.DebtSpec if err := c.ShouldBind(&debtSpec); err != nil { diff --git a/apis/extract.go b/apis/extract.go index 7ed4f7d..0a9cd86 100644 --- a/apis/extract.go +++ b/apis/extract.go @@ -1,24 +1,22 @@ package apis import ( - "github.com/gin-gonic/gin" - "gorm.io/gorm" + "database/sql" "net/http" "planner/core" + + "github.com/gin-gonic/gin" ) -func extractUser(orm *gorm.DB, c *gin.Context) *core.User { +func extractUser(orm *sql.DB, c *gin.Context) *core.User { username, _, ok := c.Request.BasicAuth() if !ok { c.Status(http.StatusUnauthorized) return nil } - u := core.User{ - Username: username, - } - result := orm.Take(&u) - if result.Error != nil { + u, err := core.UserGet(db, username) + if err != nil { c.String(http.StatusNotFound, "Unable to find user "+username) return nil } diff --git a/apis/plans.go b/apis/plans.go index 5817034..7430277 100644 --- a/apis/plans.go +++ b/apis/plans.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" + "planner/core" . "planner/core" ) @@ -22,7 +23,7 @@ func createPlan(c *gin.Context) { } c.Bind(&plan_req) - _, err := CreatePlan(db, u, plan_req.Name) + _, err := PlanCreate(db, u, plan_req.Name) if err != nil { c.JSON(http.StatusInternalServerError, err) @@ -144,7 +145,7 @@ func joinPlan(c *gin.Context) { c.Status(http.StatusBadRequest) return } - plan, err := GetPlan(db, uint(plan_id)) + plan, err := PlanGet(db, uint(plan_id)) if err != nil || plan == nil { c.Status(http.StatusInternalServerError) return @@ -171,7 +172,7 @@ func joinPlan(c *gin.Context) { c.String(http.StatusConflict, "User already a member") return } - plan.AddMember(db, &Member{UserID: user.Username}) + plan.AddMember(db, &Member{UserId: &user.Username}) c.Status(http.StatusOK) } @@ -212,10 +213,10 @@ func createPlanPoll(c *gin.Context) { } poll := Poll{ - PlanID: plan.ID, + PlanId: plan.Id, Options: poll_opts.Options, } - db.Create(&poll) + poll.Create(db) c.JSON(http.StatusCreated, poll) } @@ -236,8 +237,11 @@ func listPlanPolls(c *gin.Context) { return } - var polls []Poll - db.Where("plan_id = ?", params.Id).Find(&polls) + polls, err := core.PollsList(db, int(params.Id)) + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } c.JSON(http.StatusOK, polls) } diff --git a/apis/polls.go b/apis/polls.go index 93d43d6..fa91666 100644 --- a/apis/polls.go +++ b/apis/polls.go @@ -25,7 +25,7 @@ func getPoll(c *gin.Context) { } fmt.Println(params) - poll, _ := core.GetPoll(db, *user, params.PollId) + poll, _ := core.PollGet(db, *user, params.PollId) c.JSON(http.StatusOK, poll) } @@ -45,9 +45,12 @@ func getPollVotes(c *gin.Context) { return } - var votes []core.Vote - db.Where("poll_id = ?", params.PollId).Find(&votes) - c.JSON(http.StatusOK, &votes) + votes, err := core.VotesList(db, int(params.PollId)) + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + c.JSON(http.StatusOK, votes) } func pollVote(c *gin.Context) { @@ -66,7 +69,7 @@ func pollVote(c *gin.Context) { return } - poll, err := core.GetPoll(db, *user, path_params.PollId) + poll, err := core.PollGet(db, *user, path_params.PollId) if err != nil { c.String(http.StatusInternalServerError, err.Error()) return @@ -77,7 +80,19 @@ func pollVote(c *gin.Context) { } c.Bind(&vote_params) - if err := poll.SetVote(db, *user, vote_params.Vote); err != nil { + plan, err := core.PlanGet(db, poll.PlanId) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to find plan: "+err.Error()) + return + } + + member, err := user.GetMemberFromPlan(db, *plan) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to find member: "+err.Error()) + return + } + + if err := poll.SetVote(db, *member, vote_params.Vote); err != nil { c.String(http.StatusBadRequest, err.Error()) } diff --git a/apis/users.go b/apis/users.go index 16ad82a..3fd802c 100644 --- a/apis/users.go +++ b/apis/users.go @@ -6,7 +6,7 @@ import ( "github.com/gin-gonic/gin" - . "planner/core" + "planner/core" ) func getUserByName(c *gin.Context) { @@ -14,19 +14,23 @@ func getUserByName(c *gin.Context) { Name string `fdb:"name"` } if c.ShouldBind(&q) == nil { - user := User{ - Username: q.Name, + user, err := core.UserGet(db, q.Name) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to get user: "+err.Error()) + return } - db.Take(&user) - fmt.Println(user) c.JSON(http.StatusOK, user) } } func createUser(c *gin.Context) { - var u User + var u core.User if c.ShouldBind(&u) == nil { - db.Create(&u) + err := u.Create(db) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to create user: "+err.Error()) + return + } c.Status(http.StatusCreated) } else { fmt.Print("Could not bind model") @@ -43,12 +47,13 @@ func login(c *gin.Context) { if q.Username == "" { c.String(http.StatusBadRequest, "Login data is null") } else { - user := User{ - Username: q.Username, + user, err := core.UserGet(db, q.Username) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to get user: "+err.Error()) + return } - db.Take(&user) if user.Password == q.Password { - c.JSON(http.StatusOK, map[string]string{"username": user.Username}) + c.JSON(http.StatusOK, gin.H{"username": user.Username}) } else { c.Status(http.StatusForbidden) } diff --git a/base.sql b/base.sql new file mode 100644 index 0000000..ec3cfb8 --- /dev/null +++ b/base.sql @@ -0,0 +1,47 @@ +CREATE TABLE users ( + username TEXT PRIMARY KEY, + password TEXT NOT NULL +); + +CREATE TABLE plans ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + owner TEXT REFERENCES users(username) NOT NULL, + description TEXT DEFAULT '', + join_code TEXT NOT NULL +); + +CREATE TABLE members ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + plan_id INTEGER REFERENCES plans(id) NOT NULL, + name TEXT NOT NULL, + user_id TEXT REFERENCES users(username) +); + +CREATE TABLE polls ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + plan_id INTEGER REFERENCES plans(id) NOT NULL, + options TEXT DEFAULT '' +); + +CREATE TABLE votes ( + poll_id INTEGER REFERENCES polls(id), + member_id INTEGER REFERENCES members(id), + value text, + PRIMARY KEY (poll_id,member_id) +); + +CREATE TABLE expenses ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + plan_id INTEGER REFERENCES plans(id) NOT NULL, + payer_id INTEGER REFERENCES members(id) NOT NULL, + amount DECIMAL NOT NULL +); + +CREATE TABLE debts ( + expense_id INTEGER REFERENCES expenses(id) NOT NULL, + debtor_id INTEGER REFERENCES members(id) NOT NULL, + amount DECIMAL NOT NULL, + paid DECIMAL, + PRIMARY KEY (expense_id,debtor_id) +); diff --git a/core/debts.go b/core/debts.go new file mode 100644 index 0000000..5fcbc8b --- /dev/null +++ b/core/debts.go @@ -0,0 +1,12 @@ +package core + +func (d *Debt) Create(db SqlExecutor) error { + _, err := db.Exec( + "INSERT INTO debts(expense_id, debtor_id, amount, paid) VALUES (?,?,?,?)", + d.ExpenseId, + d.DebtorId, + d.Amount, + d.Paid, + ) + return err +} diff --git a/core/expenses.go b/core/expenses.go index 31dbbb8..fbdae05 100644 --- a/core/expenses.go +++ b/core/expenses.go @@ -1,11 +1,12 @@ package core import ( + "database/sql" "encoding/json" "errors" + "fmt" "github.com/shopspring/decimal" - "gorm.io/gorm" ) type DebtType int @@ -44,27 +45,42 @@ type DebtSpec struct { DebtType DebtType `json:"debt_type"` } -func CreateExpense(db *gorm.DB, plan Plan, payer Member, amount decimal.Decimal) (*Expense, error) { +func ExpenseCreate(db *sql.DB, plan Plan, payer Member, amount decimal.Decimal) (*Expense, error) { expense := new(Expense) - expense.Plan = plan - expense.Payer = payer + expense.PlanId = plan.Id + expense.PayerId = payer.Id expense.Amount = amount - err := db.Create(&expense).Error - if err != nil { + row := db.QueryRow( + "INSERT INTO expenses(plan_id,payer_id,amount) VALUES (?,?,?) RETURNING id", + expense.PlanId, + expense.PayerId, + expense.Amount, + ) + if err := row.Scan(&expense.Id); err != nil { return nil, err } return expense, nil } -func (e *Expense) GetDebt(db *gorm.DB) ([]Debt, error) { +func (e *Expense) GetDebt(db *sql.DB) ([]Debt, error) { var debts []Debt - if err := db.Model(e).Association("Debts").Find(&debts); err != nil { - return []Debt{}, err + rows, err := db.Query("SELECT expense_id,debtor_id,amount,paid FROM debts WHERE expense_id=?") + if err != nil { + return debts, err + } + defer rows.Close() + for rows.Next() { + var d Debt + err = rows.Scan(&d.ExpenseId, &d.DebtorId, &d.Amount, &d.Paid) + if err != nil { + return []Debt{}, err + } + debts = append(debts, d) } return debts, nil } -func (e *Expense) SetDebt(db *gorm.DB, debtSpec []DebtSpec) error { +func (e *Expense) SetDebt(db *sql.DB, debtSpec []DebtSpec) error { abs_paid := decimal.Decimal{} debts := make([]Debt, 0) var prop_payers int64 = 0 @@ -75,8 +91,8 @@ func (e *Expense) SetDebt(db *gorm.DB, debtSpec []DebtSpec) error { debts = append(debts, Debt{ Amount: debt.Amount, Paid: decimal.Decimal{}, - ExpenseID: e.ID, - DebtorID: debt.Member.ID, + ExpenseId: e.Id, + DebtorId: debt.Member.Id, }) abs_paid.Add(debt.Amount) } @@ -86,52 +102,95 @@ func (e *Expense) SetDebt(db *gorm.DB, debtSpec []DebtSpec) error { } prop_debt := e.Amount.Sub(abs_paid).DivRound(decimal.NewFromInt(prop_payers), 2) - return db.Transaction(func(tx *gorm.DB) error { - for _, debt := range debts { - if err := db.Create(&debt).Error; err != nil { - return err - } - } - - for _, debt := range debtSpec { - if debt.DebtType != ProportionalDebt { - continue - } - if err := db.Create(&Debt{ - Amount: prop_debt, - Paid: decimal.Decimal{}, - ExpenseID: e.ID, - DebtorID: debt.Member.ID, - }).Error; err != nil { - return err - } - } - return nil - }) -} - -func (e *Expense) Delete(db *gorm.DB) error { - return db.Transaction(func(tx *gorm.DB) error { - if err := db.Model(e).Association("Debts").Delete(&[]Debt{}); err != nil { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, debt := range debts { + if err := debt.Create(tx); err != nil { return err } - return db.Delete(e, e).Error - }) + } + + for _, debt := range debtSpec { + if debt.DebtType != ProportionalDebt { + continue + } + debtObj := Debt{ + Amount: prop_debt, + Paid: decimal.Decimal{}, + ExpenseId: e.Id, + DebtorId: debt.Member.Id, + } + if err := debtObj.Create(tx); err != nil { + return err + } + } + return tx.Commit() } -func ListExpenses(db *gorm.DB, plan Plan) ([]Expense, error) { +func (e *Expense) Delete(db *sql.DB) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + _, err = tx.Exec("DELETE FROM debts WHERE expense_id=?", e.Id) + if err != nil { + return fmt.Errorf("Unable to delete debts for expense %d: %w", e.Id, err) + } + _, err = tx.Exec("DELETE FROM expenses WHERE id=?", e.Id) + if err != nil { + return fmt.Errorf("Unable to delete expense %d: %w", e.Id, err) + } + err = tx.Commit() + if err != nil { + return fmt.Errorf( + "Unable to commit transaction when deleting expense %d: %w", + e.Id, + err, + ) + } + return nil +} + +func ExpensesList(db *sql.DB, plan Plan) ([]Expense, error) { var expenses []Expense - if err := db.Where("plan_id = ?", plan.ID).Find(&expenses).Error; err != nil { - return []Expense{}, err + rows, err := db.Query( + "SELECT id,plan_id,payer_id,amount FROM expenses WHERE plan_id=$1", + plan.Id, + ) + if err != nil { + return expenses, fmt.Errorf( + "Unable to query for expenses in plan %d: %w", + plan.Id, + err, + ) + } + defer rows.Close() + for rows.Next() { + e := Expense{} + err = rows.Scan(&e.Id, e.PlanId, e.PayerId, e.Amount) + if err != nil { + return expenses, fmt.Errorf( + "Unable to scan expense during list: %w", + err, + ) + } + expenses = append(expenses, e) } return expenses, nil } -func GetExpense(db *gorm.DB, expense_id uint) (*Expense, error) { - expense := new(Expense) - expense.ID = expense_id - if err := db.Take(expense).Error; err != nil { - return nil, err +func ExpensesGet(db *sql.DB, expense_id uint) (*Expense, error) { + e := Expense{} + row := db.QueryRow( + "SELECT id,plan_id,payer_id,amount FROM expenses WHERE id=?", + expense_id, + ) + if err := row.Scan(&e.Id, &e.PlanId, &e.PayerId, &e.Amount); err != nil { + return nil, fmt.Errorf("Unable to get expense %d: %w", expense_id, err) } - return expense, nil + return &e, nil } diff --git a/core/models.go b/core/models.go index f4bceb8..f5fe490 100644 --- a/core/models.go +++ b/core/models.go @@ -3,64 +3,52 @@ package core import "github.com/shopspring/decimal" type User struct { - Username string `gorm:"primaryKey" json:"username"` - Password string `json:"password"` - OwnedPlans []Plan `gorm:"foreignKey:Owner;references:Username" json:"-"` - MemberOf []Member `json:"-"` - Votes []Vote `gorm:"foreignKey:UsernameID" json:"-"` + Username string `json:"username"` + Password string `json:"password"` } type Member struct { - ID uint `gorm:"primaryKey;autoIncrement:true" json:"-"` - PlanID uint `json:"-"` - Plan Plan `json:"-"` - Name string `gorm:"check:user_id IS NOT NULL OR name IS NOT NULL" json:"name,omitempty"` - UserID string `json:"username,omitempty"` - User User `gorm:"foreignKey:UserID" json:"-"` + Id uint `json:"-"` + PlanId uint `json:"-"` + Name string `json:"name,omitempty"` + UserId *string `json:"username,omitempty"` } // CREATE TABLE plans(id INTEGER PRIMARY KEY AUTOINCREMENT, name STRING, owner STRING, FOREIGN KEY(owner) REFERENCES users(username)) // CREATE TABLE plan_user_relations(username STRING, plan INTEGER, PRIMARY KEY(username, plan), FOREIGN KEY username REFERENCES user(username), FOREIGN KEY plan REFERENCES plans(id)) type Plan struct { - ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` - Name string `json:"name"` - Owner string `json:"owner"` - Description string `json:"description"` - JoinCode string `gorm:"not null" json:"join_code,omitempty"` - Members []Member `json:"-"` - Polls []Poll `gorm:"foreignKey:PlanID;references:ID" json:"-"` + Id uint `json:"id"` + Name string `json:"name"` + Owner string `json:"owner"` + Description string `json:"description"` + JoinCode string `json:"join_code,omitempty"` } type Expense struct { - ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` - PlanID uint `json:"-"` - Plan Plan `json:"-"` - PayerID uint `json:"-"` - Payer Member `gorm:"foreignKey:PayerID" json:"-"` + Id uint `json:"id"` + PlanId uint `json:"-"` + PayerId uint `json:"-"` Amount decimal.Decimal `json:"amount"` - Debts []Debt `gorm:"foreignKey:ExpenseID" json:"debts,omitempty"` } type Debt struct { - ExpenseID uint `gorm:"primaryKey" json:"-"` - Expense Expense `gorm:"foreignKey:ExpenseID"` - DebtorID uint `gorm:"primaryKey" json:"-"` - Debtor Member `gorm:"foreignKey:DebtorID"` + Id uint + ExpenseId uint `json:"-"` + DebtorId uint `json:"-"` Amount decimal.Decimal Paid decimal.Decimal } // CREATE TABLE polls(id INTEGER PRIMARY KEY AUTOINCREMENT, plan INTEGER, name STRING, options JSON, FOREIGN KEY plan REFERENCES plans(id)) type Poll struct { - ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` - PlanID uint `json:"-"` + Id uint `json:"id"` + PlanId uint `json:"-"` Options string `json:"options"` - Votes []Vote `gorm:"foreignKey:PollID;references:ID" json:"-"` } // CREATE TABLE votes(id INTEGER, poll INTEGER, user STRING, value JSON, FOREIGN KEY poll REFERENCES polls(id), FOREIGN KEY user REFERENCES user(username)) type Vote struct { - PollID uint `gorm:"primaryKey" json:"-"` - UsernameID string `gorm:"primaryKey" json:"username_id"` - Value string `json:"value"` + PollId uint `json:"-"` + MemberId uint `json:"member_id"` + Value string `json:"value"` } diff --git a/core/plans.go b/core/plans.go index 28403c9..ac56514 100644 --- a/core/plans.go +++ b/core/plans.go @@ -2,13 +2,13 @@ package core import ( "crypto/rand" + "database/sql" "encoding/base64" "errors" - - "gorm.io/gorm" + "fmt" ) -func CreatePlan(db *gorm.DB, user *User, name string) (*Plan, error) { +func PlanCreate(db *sql.DB, user *User, name string) (*Plan, error) { join_code := make([]byte, 32) _, err := rand.Read(join_code) if err != nil { @@ -16,100 +16,101 @@ func CreatePlan(db *gorm.DB, user *User, name string) (*Plan, error) { } var plan Plan = Plan{ - Name: name, - Owner: user.Username, - Members: []Member{ - { - UserID: user.Username, - }, - }, + Name: name, + Owner: user.Username, JoinCode: base64.URLEncoding.EncodeToString(join_code), } - result := db.Create(&plan) - if result.Error != nil { - return nil, result.Error + row := db.QueryRow( + "INSERT INTO plans(name, owner, description, join_jode) VALUES (?, ?, '', ?) RETURNING id", + plan.Name, + plan.Owner, + plan.JoinCode, + ) + if err := row.Scan(&plan.Id); err != nil { + return nil, err } return &plan, nil } -func GetPlan(orm *gorm.DB, id uint) (*Plan, error) { - var plan Plan = Plan{ - ID: id, +func PlanGet(db *sql.DB, id uint) (*Plan, error) { + plan := Plan{} + row := db.QueryRow("SELECT name,owner,description,join_code FROM plans WHERE id=?", id) + if err := row.Scan(&plan.Name, &plan.Owner, &plan.Description, &plan.JoinCode); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("Unexpected database error: %w", err) } - result := orm.Take(&plan) - - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } else if result.Error != nil { - return nil, result.Error - } - return &plan, nil } -func (p *Plan) GetAllUsers(orm *gorm.DB) ([]Member, error) { - var members []Member - err := orm.Model(p).Association("Members").Find(&members) +func (p *Plan) GetAllUsers(db *sql.DB) ([]Member, error) { + members := []Member{} + rows, err := db.Query("SELECT id,plan_id,name,user_id FROM members WHERE plan_id=?", p.Id) + if err != nil { + return members, err + } + defer rows.Close() + for rows.Next() { + var m Member + rows.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId) + members = append(members, m) + } return members, err } -func (p *Plan) IsMember(orm *gorm.DB, u *User) (bool, error) { - var member_count int64 - result := orm. - Table("members"). - Where("user_id=? AND plan_id=?", u.Username, p.ID). - Count(&member_count).Error - if result != nil { - return false, ErrNotMember +func (p *Plan) IsMember(db *sql.DB, u *User) (bool, error) { + var member_count int + row := db.QueryRow("SELECT count(1) FROM members WHERE plan_id=? AND user_id=?", p.Id, u.Username) + if err := row.Scan(&member_count); err != nil { + return false, nil } return member_count == 1, nil } -func (p *Plan) GetMember(orm *gorm.DB, u *User) (*Member, error) { +func (p *Plan) GetMember(db *sql.DB, u *User) (*Member, error) { var m Member - result := orm. - Table("members"). - Where("user_id=? AND plan_id=?", u.Username, p.ID). - Take(&m).Error - if result != nil { - return nil, ErrNotMember + row := db.QueryRow("SELECT id,plan_id,name,user_id FROM members WHERE plan_id=? AND user_id=?", p.Id, u.Username) + if err := row.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotMember + } + return nil, fmt.Errorf("Unable to check if user %s is member in plan %d: %w", u.Username, p.Id, err) } return &m, nil } -func (p *Plan) HasNonUser(orm *gorm.DB, name string) (bool, error) { - var member_count int64 - result := orm. - Table("members"). - Where("plan_id=? AND name=?", p.ID, name). - Count(&member_count).Error - if result != nil { - return false, result +func (p *Plan) HasNonUser(db *sql.DB, name string) (bool, error) { + var member_count int + row := db.QueryRow("SELECT count(1) FROM members WHERE plan_id=? AND name=?", p.Id, name) + if err := row.Scan(&member_count); err != nil { + return false, nil } return member_count == 1, nil } -func (p *Plan) AddMember(orm *gorm.DB, new_member *Member) error { +func (p *Plan) AddMember(db *sql.DB, new_member *Member) error { if new_member == nil { return errors.New("Member is nil") } - new_member.PlanID = p.ID + new_member.PlanId = p.Id if new_member.Name != "" { - found, err := p.HasNonUser(orm, new_member.Name) + found, err := p.HasNonUser(db, new_member.Name) if err != nil { return nil } if found { return errors.New("Non user name taken") } - return orm.Create(&new_member).Error - } else if new_member.UserID != "" { - user, err := GetUser(orm, new_member.UserID) + _, err = db.Exec("INSERT INTO members(plan_id, name) VALUES (?,?)", new_member.PlanId, new_member.Name) + return err + } else if *new_member.UserId != "" { + user, err := UserGet(db, *new_member.UserId) if err != nil { return err } - found, err := p.IsMember(orm, &user) + found, err := p.IsMember(db, &user) if err != nil { return nil } @@ -117,7 +118,8 @@ func (p *Plan) AddMember(orm *gorm.DB, new_member *Member) error { return errors.New("User already is member") } - return orm.Create(&new_member).Error + _, err = db.Exec("INSERT INTO members(plan_id, user_id) VALUES (?,?)", new_member.PlanId, new_member.UserId) + return err } else { return errors.New("Member object requires one of Name or UserID to be filled") } diff --git a/core/polls.go b/core/polls.go index 564d0b5..7401c30 100644 --- a/core/polls.go +++ b/core/polls.go @@ -1,43 +1,59 @@ package core import ( + "database/sql" "fmt" + "slices" "strings" - - "gorm.io/gorm" ) -func GetPoll(orm *gorm.DB, user User, id uint) (*Poll, error) { - var poll Poll = Poll{ - ID: id, +func (p *Poll) Create(db *sql.DB) error { + row := db.QueryRow( + "INSERT INTO polls(plan_id, options) VALUES(?,?) RETURNING id", + p.PlanId, + p.Options, + ) + if err := row.Scan(&p.Id); err != nil { + return fmt.Errorf("Unable to create new poll for plan %d: %w", p.PlanId, err) } + return nil +} - if result := orm.Take(&poll); result.Error != nil { - return nil, result.Error +func PollGet(db *sql.DB, user User, id uint) (*Poll, error) { + var poll Poll + row := db.QueryRow("SELECT id,plan_id,options FROM polls WHERE id=?", id) + if err := row.Scan(&poll.Id, &poll.PlanId, &poll.Options); err != nil { + return nil, err } - fmt.Printf("%+v\n", poll.PlanID) - return &poll, nil } -func (p *Poll) SetVote(orm *gorm.DB, user User, option string) error { - found := false +func PollsList(db *sql.DB, plan_id int) ([]Poll, error) { + var polls []Poll + rows, err := db.Query("SELECT id,plan_id,options FROM polls WHERE plan_id=?", plan_id) + if err != nil { + return polls, fmt.Errorf("Unable to query polls for plan %d: %w", plan_id, err) + } + defer rows.Close() + for rows.Next() { + var poll Poll + if err := rows.Scan(&poll.Id, &poll.PlanId, &poll.Options); err != nil { + return nil, fmt.Errorf("Unable to scan polls for plan %d: %w", plan_id, err) + } + polls = append(polls, poll) + } + + return polls, nil +} + +func (p *Poll) SetVote(db *sql.DB, member Member, option string) error { options := strings.Split(p.Options, ",") - for _, opt := range options { - if opt == option { - found = true - break - } - } - if !found { - return ErrInvalidOption + if !slices.Contains(options, option) { + return fmt.Errorf("%s is not a valid option (%s): %w", option, options, ErrInvalidOption) } + _, err := db.Exec("INSERT INTO votes(poll_id,member_id,value) VALUES (?,?,?)", p.Id, member.Id, option) - if res := orm.Create(Vote{PollID: p.ID, UsernameID: user.Username, Value: option}); res.Error != nil { - return res.Error - } - - return nil + return err } diff --git a/core/sqlexecutor.go b/core/sqlexecutor.go new file mode 100644 index 0000000..89bfdb2 --- /dev/null +++ b/core/sqlexecutor.go @@ -0,0 +1,9 @@ +package core + +import "database/sql" + +type SqlExecutor interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} diff --git a/core/users.go b/core/users.go index b863ec6..9f0956f 100644 --- a/core/users.go +++ b/core/users.go @@ -1,31 +1,49 @@ package core import ( + "database/sql" "errors" - - "gorm.io/gorm" + "fmt" ) -func (u *User) ListPlans(orm *gorm.DB) ([]Plan, error) { +func (u *User) ListPlans(db *sql.DB) ([]Plan, error) { var plans []Plan - err := orm.Debug().Table("plans p"). - Select("p.*"). - Joins("JOIN members m ON m.plan_id=p.id"). - Where("m.user_id=?", u.Username). - Find(&plans) - return plans, err.Error + rows, err := db.Query( + "SELECT p.id,p.name,p.owner,p.description,p.join_code FROM plans p "+ + "JOIN members m ON m.plan_id=p.id "+ + "WHERE m.user_id=?", + u.Username, + ) + if err != nil { + return plans, fmt.Errorf( + "Unable to query plans for user %s: %w", + u.Username, + err, + ) + } + defer rows.Close() + for rows.Next() { + p := Plan{} + err = rows.Scan(&p.Id, &p.Name, &p.Owner, &p.Description, &p.JoinCode) + if err != nil { + return plans, fmt.Errorf( + "Unable to scan plan for user %s: %w", + u.Username, + err, + ) + } + plans = append(plans, p) + } + return plans, nil } -func (u *User) GetPlan(db *gorm.DB, plan_id uint) (Plan, error) { - var plan Plan = Plan{ - ID: plan_id, - } - result := db.Take(&plan) +func (u *User) GetPlan(db *sql.DB, plan_id uint) (*Plan, error) { + plan, err := PlanGet(db, plan_id) - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return plan, ErrNotFound - } else if result.Error != nil { - return plan, result.Error + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } else if err != nil { + return nil, err } if plan.Owner == u.Username { @@ -33,18 +51,43 @@ func (u *User) GetPlan(db *gorm.DB, plan_id uint) (Plan, error) { } isMember, err := plan.IsMember(db, u) - if !isMember || err != nil { - return plan, err + return nil, err } return plan, nil } -func GetUser(orm *gorm.DB, username string) (User, error) { - user := User{Username: username} - if err := orm.Take(&user).Error; err != nil { - return user, err +func (u *User) GetMemberFromPlan(db *sql.DB, p Plan) (*Member, error) { + m := Member{} + row := db.QueryRow( + "SELECT id,plan_id,name,user_id FROM members WHERE user_id=? AND plan_id=?", + u.Username, + p.Id, + ) + if err := row.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId); err != nil { + return nil, fmt.Errorf( + "Unable to get member from user %s and plan %d: %w", + u.Username, + p.Id, + err, + ) } - return user, nil + return &m, nil +} + +func UserGet(db *sql.DB, username string) (User, error) { + var user User + row := db.QueryRow("SELECT username, password FROM users WHERE username=?", username) + err := row.Scan(&user.Username, &user.Password) + return user, err +} + +func (u *User) Create(db *sql.DB) error { + _, err := db.Exec( + "INSERT INTO users(username, password) VALUES (?,?)", + u.Username, + u.Password, + ) + return fmt.Errorf("Unable to create user %s: %w", u.Username, err) } diff --git a/core/votes.go b/core/votes.go new file mode 100644 index 0000000..81ddd0b --- /dev/null +++ b/core/votes.go @@ -0,0 +1,23 @@ +package core + +import ( + "database/sql" + "fmt" +) + +func VotesList(db *sql.DB, poll_id int) ([]Vote, error) { + votes := []Vote{} + rows, err := db.Query("SELECT poll_id,member_id,value FROM polls WHERE poll_id=?", poll_id) + if err != nil { + return votes, fmt.Errorf("Unable to get votes for poll %d: %w", poll_id, err) + } + defer rows.Close() + for rows.Next() { + v := Vote{} + err = rows.Scan(&v.PollId, &v.MemberId, &v.Value) + if err != nil { + return votes, fmt.Errorf("Unable to scan vote for poll %d: %w", poll_id, err) + } + } + return votes, nil +} diff --git a/db.go b/db.go index b0e63a8..c90bf34 100644 --- a/db.go +++ b/db.go @@ -1,45 +1,33 @@ package main import ( - "gorm.io/driver/sqlite" - "gorm.io/gorm" + "database/sql" "log" "os" - . "planner/core" + + _ "github.com/mattn/go-sqlite3" ) -func bootstrapDatabase() *gorm.DB { +func bootstrapDatabase() *sql.DB { fi, err := os.Stat("./db.sqlite") - var db *gorm.DB + var db *sql.DB if err != nil { if os.IsNotExist(err) { - db, err := gorm.Open(sqlite.Open("./db.sqlite"), &gorm.Config{}) + db, err := sql.Open("sqlite3", "./db.sqlite") if err != nil { log.Fatal(err) return nil } - db.AutoMigrate(&User{}, &Member{}, &Plan{}, &Expense{}, &Debt{}, &Poll{}, &Vote{}) - - //var tables = [...]struct { - // key string - // query string - //}{ - // {key: "users", query: "CREATE TABLE users(username STRING PRIMARY KEY, password STRING)"}, - // {key: "plans", query: "CREATE TABLE plans(id INTEGER PRIMARY KEY AUTOINCREMENT, name STRING, owner STRING, FOREIGN KEY(owner) REFERENCES users(username))"}, - // {key: "plan_user_relations", query: "CREATE TABLE plan_user_relations(username STRING, plan INTEGER, PRIMARY KEY(username, plan), FOREIGN KEY username REFERENCES user(username), FOREIGN KEY plan REFERENCES plans(id))"}, - // {key: "polls", query: "CREATE TABLE polls(id INTEGER PRIMARY KEY AUTOINCREMENT, plan INTEGER, name STRING, options JSON, FOREIGN KEY plan REFERENCES plans(id))"}, - // {key: "votes", query: "CREATE TABLE votes(id INTEGER, poll INTEGER, user STRING, value JSON, FOREIGN KEY poll REFERENCES polls(id), FOREIGN KEY user REFERENCES user(username))"}, - //} - - //for _, table := range tables { - // _, err = db.Exec(table.query) - // if err != nil { - // log.Fatal("Failed to create " + table.key + " table") - // log.Fatal(err) - // return false - // } - //} + file, err := os.ReadFile("./base.sql") + if err != nil { + log.Fatal("Unable to open base.sql", err) + } + contents := string(file) + _, err = db.Exec(contents) + if err != nil { + log.Fatal("Unable to create base db", err) + } return db } else { log.Fatal(err) @@ -52,7 +40,7 @@ func bootstrapDatabase() *gorm.DB { return nil } - db, err = gorm.Open(sqlite.Open("./db.sqlite"), &gorm.Config{}) + db, err = sql.Open("sqlite3", "./db.sqlite") if err != nil { log.Fatal(err) diff --git a/go.mod b/go.mod index 9bef346..c36d2e7 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,8 @@ go 1.23.0 require ( github.com/gin-gonic/gin v1.10.0 - github.com/mattn/go-sqlite3 v1.14.23 - gorm.io/driver/sqlite v1.5.6 - gorm.io/gorm v1.25.12 + github.com/mattn/go-sqlite3 v1.14.27 + github.com/shopspring/decimal v1.4.0 ) require ( @@ -20,8 +19,6 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.22.0 // indirect github.com/goccy/go-json v0.10.3 // indirect - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/leodido/go-urn v1.4.0 // indirect @@ -29,7 +26,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/shopspring/decimal v1.4.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.10.0 // indirect diff --git a/go.sum b/go.sum index 96934a0..f23107e 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.27 h1:drZCnuvf37yPfs95E5jd9s3XhdVWLal+6BOK6qrv6IU= +github.com/mattn/go-sqlite3 v1.14.27/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/planner.go b/planner.go index f30a018..907eed2 100644 --- a/planner.go +++ b/planner.go @@ -12,17 +12,16 @@ import ( func main() { fmt.Println("Opening database db.sqlite") - orm := bootstrapDatabase() + db := bootstrapDatabase() - if orm == nil { + if db == nil { log.Fatal("Failed to init database") return } - db, _ := orm.DB() defer db.Close() r := gin.Default() - apis.BindAPIs(r, orm) + apis.BindAPIs(r, db) r.Run() }