diff --git a/apis/base.go b/apis/base.go index bbf34da..a577c85 100644 --- a/apis/base.go +++ b/apis/base.go @@ -1,16 +1,16 @@ package apis import ( + "database/sql" "errors" "net/http" "github.com/gin-gonic/gin" - "gorm.io/gorm" ) -var db *gorm.DB +var db *sql.DB -func BindAPIs(r *gin.Engine, cfg_db *gorm.DB) error { +func BindAPIs(r *gin.Engine, cfg_db *sql.DB) error { if cfg_db == nil { return errors.New("Database is null") } @@ -23,5 +23,6 @@ func BindAPIs(r *gin.Engine, cfg_db *gorm.DB) error { bindPlanAPIs(r) bindPollAPIs(r) bindUserAPIs(r) + bindExpensesAPIs(r) return nil } diff --git a/apis/expenses.go b/apis/expenses.go new file mode 100644 index 0000000..68a7685 --- /dev/null +++ b/apis/expenses.go @@ -0,0 +1,230 @@ +package apis + +import ( + "net/http" + "planner/core" + "slices" + "strconv" + + "github.com/gin-gonic/gin" + "github.com/shopspring/decimal" +) + +type ExpenseDef struct { + Payer string `json:"payer"` + Type string `json:"type"` + Amount decimal.Decimal `json:"amount"` +} + +func createExpense(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + plan_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || plan_id < 0 { + c.Status(http.StatusBadRequest) + return + } + plan, err := u.GetPlan(db, uint(plan_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + members, err := plan.GetAllUsers(db) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + var exp ExpenseDef + if err := c.ShouldBind(&exp); err != nil { + c.Status(http.StatusBadRequest) + return + } + + var compare func(m core.Member) bool + + if exp.Type == "user" { + compare = func(m core.Member) bool { + return *m.UserId == exp.Payer + } + } else if exp.Type == "non-user" { + compare = func(m core.Member) bool { + return m.Name == exp.Payer + } + } else { + c.String(http.StatusBadRequest, "Invalid member type") + return + } + idx := slices.IndexFunc(members, compare) + if idx != -1 { + expense, err := core.ExpenseCreate(db, *plan, members[idx], exp.Amount) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + c.JSON(http.StatusOK, expense) + return + } + c.String(http.StatusBadRequest, "Unable to found member") +} + +func listExpenses(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + plan_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || plan_id < 0 { + c.Status(http.StatusBadRequest) + return + } + plan, err := u.GetPlan(db, uint(plan_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + expenses, err := core.ExpensesList(db, *plan) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + c.JSON(http.StatusOK, expenses) +} + +func getExpense(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || expense_id < 0 { + c.Status(http.StatusBadRequest) + return + } + expense, err := core.ExpensesGet(db, uint(expense_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + _, err = u.GetPlan(db, expense.PlanId) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + c.JSON(http.StatusOK, expense) +} + +func deleteExpense(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || expense_id < 0 { + c.Status(http.StatusBadRequest) + return + } + expense, err := core.ExpensesGet(db, uint(expense_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + _, err = u.GetPlan(db, expense.PlanId) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + if err = expense.Delete(db); err != nil { + c.Status(http.StatusInternalServerError) + return + } + + c.Status(http.StatusOK) +} + +func getExpenseDebts(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || expense_id < 0 { + c.Status(http.StatusBadRequest) + return + } + expense, err := core.ExpensesGet(db, uint(expense_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + _, err = u.GetPlan(db, expense.PlanId) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + debts, err := expense.GetDebt(db) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, debts) +} + +func setExpenseDebts(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || expense_id < 0 { + c.Status(http.StatusBadRequest) + return + } + expense, err := core.ExpensesGet(db, uint(expense_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + _, err = u.GetPlan(db, expense.PlanId) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + var debtSpec []core.DebtSpec + if err := c.ShouldBind(&debtSpec); err != nil { + c.Status(http.StatusBadRequest) + return + } + + //expense.SetDebt(db, debtSpec) + + c.JSON(http.StatusOK, debtSpec) +} + +func bindExpensesAPIs(r *gin.Engine) { + r.POST("/plans/:id/expenses", createExpense) + r.GET("/plans/:id/expenses", listExpenses) + r.GET("/expenses/:id", getExpense) + r.DELETE("/expenses/:id", deleteExpense) + r.GET("/expenses/:id/debts", getExpenseDebts) + r.POST("/expenses/:id/debts", setExpenseDebts) +} diff --git a/apis/extract.go b/apis/extract.go index 7ed4f7d..0a9cd86 100644 --- a/apis/extract.go +++ b/apis/extract.go @@ -1,24 +1,22 @@ package apis import ( - "github.com/gin-gonic/gin" - "gorm.io/gorm" + "database/sql" "net/http" "planner/core" + + "github.com/gin-gonic/gin" ) -func extractUser(orm *gorm.DB, c *gin.Context) *core.User { +func extractUser(orm *sql.DB, c *gin.Context) *core.User { username, _, ok := c.Request.BasicAuth() if !ok { c.Status(http.StatusUnauthorized) return nil } - u := core.User{ - Username: username, - } - result := orm.Take(&u) - if result.Error != nil { + u, err := core.UserGet(db, username) + if err != nil { c.String(http.StatusNotFound, "Unable to find user "+username) return nil } diff --git a/apis/plans.go b/apis/plans.go index 2b40112..7430277 100644 --- a/apis/plans.go +++ b/apis/plans.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" + "planner/core" . "planner/core" ) @@ -22,7 +23,7 @@ func createPlan(c *gin.Context) { } c.Bind(&plan_req) - _, err := CreatePlan(db, u, plan_req.Name) + _, err := PlanCreate(db, u, plan_req.Name) if err != nil { c.JSON(http.StatusInternalServerError, err) @@ -119,7 +120,7 @@ func addPlanMember(c *gin.Context) { return } - err = plan.AddMember(db, &Member{Name: new_member.Name, Type: "non-user"}) + err = plan.AddMember(db, &Member{Name: new_member.Name}) if err == nil { c.JSON(http.StatusOK, new_member) @@ -144,7 +145,7 @@ func joinPlan(c *gin.Context) { c.Status(http.StatusBadRequest) return } - plan, err := GetPlan(db, uint(plan_id)) + plan, err := PlanGet(db, uint(plan_id)) if err != nil || plan == nil { c.Status(http.StatusInternalServerError) return @@ -171,7 +172,7 @@ func joinPlan(c *gin.Context) { c.String(http.StatusConflict, "User already a member") return } - plan.AddMember(db, &Member{Type: "user", UserID: user.Username}) + plan.AddMember(db, &Member{UserId: &user.Username}) c.Status(http.StatusOK) } @@ -212,10 +213,10 @@ func createPlanPoll(c *gin.Context) { } poll := Poll{ - PlanID: plan.ID, + PlanId: plan.Id, Options: poll_opts.Options, } - db.Create(&poll) + poll.Create(db) c.JSON(http.StatusCreated, poll) } @@ -236,8 +237,11 @@ func listPlanPolls(c *gin.Context) { return } - var polls []Poll - db.Where("plan_id = ?", params.Id).Find(&polls) + polls, err := core.PollsList(db, int(params.Id)) + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } c.JSON(http.StatusOK, polls) } diff --git a/apis/polls.go b/apis/polls.go index 93d43d6..fa91666 100644 --- a/apis/polls.go +++ b/apis/polls.go @@ -25,7 +25,7 @@ func getPoll(c *gin.Context) { } fmt.Println(params) - poll, _ := core.GetPoll(db, *user, params.PollId) + poll, _ := core.PollGet(db, *user, params.PollId) c.JSON(http.StatusOK, poll) } @@ -45,9 +45,12 @@ func getPollVotes(c *gin.Context) { return } - var votes []core.Vote - db.Where("poll_id = ?", params.PollId).Find(&votes) - c.JSON(http.StatusOK, &votes) + votes, err := core.VotesList(db, int(params.PollId)) + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + c.JSON(http.StatusOK, votes) } func pollVote(c *gin.Context) { @@ -66,7 +69,7 @@ func pollVote(c *gin.Context) { return } - poll, err := core.GetPoll(db, *user, path_params.PollId) + poll, err := core.PollGet(db, *user, path_params.PollId) if err != nil { c.String(http.StatusInternalServerError, err.Error()) return @@ -77,7 +80,19 @@ func pollVote(c *gin.Context) { } c.Bind(&vote_params) - if err := poll.SetVote(db, *user, vote_params.Vote); err != nil { + plan, err := core.PlanGet(db, poll.PlanId) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to find plan: "+err.Error()) + return + } + + member, err := user.GetMemberFromPlan(db, *plan) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to find member: "+err.Error()) + return + } + + if err := poll.SetVote(db, *member, vote_params.Vote); err != nil { c.String(http.StatusBadRequest, err.Error()) } diff --git a/apis/users.go b/apis/users.go index 16ad82a..3fd802c 100644 --- a/apis/users.go +++ b/apis/users.go @@ -6,7 +6,7 @@ import ( "github.com/gin-gonic/gin" - . "planner/core" + "planner/core" ) func getUserByName(c *gin.Context) { @@ -14,19 +14,23 @@ func getUserByName(c *gin.Context) { Name string `fdb:"name"` } if c.ShouldBind(&q) == nil { - user := User{ - Username: q.Name, + user, err := core.UserGet(db, q.Name) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to get user: "+err.Error()) + return } - db.Take(&user) - fmt.Println(user) c.JSON(http.StatusOK, user) } } func createUser(c *gin.Context) { - var u User + var u core.User if c.ShouldBind(&u) == nil { - db.Create(&u) + err := u.Create(db) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to create user: "+err.Error()) + return + } c.Status(http.StatusCreated) } else { fmt.Print("Could not bind model") @@ -43,12 +47,13 @@ func login(c *gin.Context) { if q.Username == "" { c.String(http.StatusBadRequest, "Login data is null") } else { - user := User{ - Username: q.Username, + user, err := core.UserGet(db, q.Username) + if err != nil { + c.String(http.StatusInternalServerError, "Unable to get user: "+err.Error()) + return } - db.Take(&user) if user.Password == q.Password { - c.JSON(http.StatusOK, map[string]string{"username": user.Username}) + c.JSON(http.StatusOK, gin.H{"username": user.Username}) } else { c.Status(http.StatusForbidden) } diff --git a/base.sql b/base.sql new file mode 100644 index 0000000..ec3cfb8 --- /dev/null +++ b/base.sql @@ -0,0 +1,47 @@ +CREATE TABLE users ( + username TEXT PRIMARY KEY, + password TEXT NOT NULL +); + +CREATE TABLE plans ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + owner TEXT REFERENCES users(username) NOT NULL, + description TEXT DEFAULT '', + join_code TEXT NOT NULL +); + +CREATE TABLE members ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + plan_id INTEGER REFERENCES plans(id) NOT NULL, + name TEXT NOT NULL, + user_id TEXT REFERENCES users(username) +); + +CREATE TABLE polls ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + plan_id INTEGER REFERENCES plans(id) NOT NULL, + options TEXT DEFAULT '' +); + +CREATE TABLE votes ( + poll_id INTEGER REFERENCES polls(id), + member_id INTEGER REFERENCES members(id), + value text, + PRIMARY KEY (poll_id,member_id) +); + +CREATE TABLE expenses ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + plan_id INTEGER REFERENCES plans(id) NOT NULL, + payer_id INTEGER REFERENCES members(id) NOT NULL, + amount DECIMAL NOT NULL +); + +CREATE TABLE debts ( + expense_id INTEGER REFERENCES expenses(id) NOT NULL, + debtor_id INTEGER REFERENCES members(id) NOT NULL, + amount DECIMAL NOT NULL, + paid DECIMAL, + PRIMARY KEY (expense_id,debtor_id) +); diff --git a/core/debts.go b/core/debts.go new file mode 100644 index 0000000..5fcbc8b --- /dev/null +++ b/core/debts.go @@ -0,0 +1,12 @@ +package core + +func (d *Debt) Create(db SqlExecutor) error { + _, err := db.Exec( + "INSERT INTO debts(expense_id, debtor_id, amount, paid) VALUES (?,?,?,?)", + d.ExpenseId, + d.DebtorId, + d.Amount, + d.Paid, + ) + return err +} diff --git a/core/expenses.go b/core/expenses.go new file mode 100644 index 0000000..fbdae05 --- /dev/null +++ b/core/expenses.go @@ -0,0 +1,196 @@ +package core + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + + "github.com/shopspring/decimal" +) + +type DebtType int + +const ( + ProportionalDebt DebtType = iota + AbsoluteDebt = iota +) + +func (dt DebtType) String() string { + return [...]string{"proportional", "absolute"}[dt] +} + +func (dt DebtType) MarshalJSON() ([]byte, error) { + return json.Marshal(dt.String()) +} + +func (dt *DebtType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil + } + if s == "proportional" { + *dt = ProportionalDebt + } else if s == "absolute" { + *dt = AbsoluteDebt + } else { + return errors.New("Invalid value for debt type") + } + return nil +} + +type DebtSpec struct { + Member Member `json:"member"` + Amount decimal.Decimal `json:"amount"` + DebtType DebtType `json:"debt_type"` +} + +func ExpenseCreate(db *sql.DB, plan Plan, payer Member, amount decimal.Decimal) (*Expense, error) { + expense := new(Expense) + expense.PlanId = plan.Id + expense.PayerId = payer.Id + expense.Amount = amount + row := db.QueryRow( + "INSERT INTO expenses(plan_id,payer_id,amount) VALUES (?,?,?) RETURNING id", + expense.PlanId, + expense.PayerId, + expense.Amount, + ) + if err := row.Scan(&expense.Id); err != nil { + return nil, err + } + return expense, nil +} + +func (e *Expense) GetDebt(db *sql.DB) ([]Debt, error) { + var debts []Debt + rows, err := db.Query("SELECT expense_id,debtor_id,amount,paid FROM debts WHERE expense_id=?") + if err != nil { + return debts, err + } + defer rows.Close() + for rows.Next() { + var d Debt + err = rows.Scan(&d.ExpenseId, &d.DebtorId, &d.Amount, &d.Paid) + if err != nil { + return []Debt{}, err + } + debts = append(debts, d) + } + return debts, nil +} + +func (e *Expense) SetDebt(db *sql.DB, debtSpec []DebtSpec) error { + abs_paid := decimal.Decimal{} + debts := make([]Debt, 0) + var prop_payers int64 = 0 + for _, debt := range debtSpec { + if debt.DebtType == ProportionalDebt { + prop_payers += 1 + } else { + debts = append(debts, Debt{ + Amount: debt.Amount, + Paid: decimal.Decimal{}, + ExpenseId: e.Id, + DebtorId: debt.Member.Id, + }) + abs_paid.Add(debt.Amount) + } + } + if abs_paid.Cmp(e.Amount) > 0 { + return errors.New("Absolute pay amount is larger than debt") + } + + prop_debt := e.Amount.Sub(abs_paid).DivRound(decimal.NewFromInt(prop_payers), 2) + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, debt := range debts { + if err := debt.Create(tx); err != nil { + return err + } + } + + for _, debt := range debtSpec { + if debt.DebtType != ProportionalDebt { + continue + } + debtObj := Debt{ + Amount: prop_debt, + Paid: decimal.Decimal{}, + ExpenseId: e.Id, + DebtorId: debt.Member.Id, + } + if err := debtObj.Create(tx); err != nil { + return err + } + } + return tx.Commit() +} + +func (e *Expense) Delete(db *sql.DB) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + _, err = tx.Exec("DELETE FROM debts WHERE expense_id=?", e.Id) + if err != nil { + return fmt.Errorf("Unable to delete debts for expense %d: %w", e.Id, err) + } + _, err = tx.Exec("DELETE FROM expenses WHERE id=?", e.Id) + if err != nil { + return fmt.Errorf("Unable to delete expense %d: %w", e.Id, err) + } + err = tx.Commit() + if err != nil { + return fmt.Errorf( + "Unable to commit transaction when deleting expense %d: %w", + e.Id, + err, + ) + } + return nil +} + +func ExpensesList(db *sql.DB, plan Plan) ([]Expense, error) { + var expenses []Expense + rows, err := db.Query( + "SELECT id,plan_id,payer_id,amount FROM expenses WHERE plan_id=$1", + plan.Id, + ) + if err != nil { + return expenses, fmt.Errorf( + "Unable to query for expenses in plan %d: %w", + plan.Id, + err, + ) + } + defer rows.Close() + for rows.Next() { + e := Expense{} + err = rows.Scan(&e.Id, e.PlanId, e.PayerId, e.Amount) + if err != nil { + return expenses, fmt.Errorf( + "Unable to scan expense during list: %w", + err, + ) + } + expenses = append(expenses, e) + } + return expenses, nil +} + +func ExpensesGet(db *sql.DB, expense_id uint) (*Expense, error) { + e := Expense{} + row := db.QueryRow( + "SELECT id,plan_id,payer_id,amount FROM expenses WHERE id=?", + expense_id, + ) + if err := row.Scan(&e.Id, &e.PlanId, &e.PayerId, &e.Amount); err != nil { + return nil, fmt.Errorf("Unable to get expense %d: %w", expense_id, err) + } + return &e, nil +} diff --git a/core/models.go b/core/models.go index 53268c0..f5fe490 100644 --- a/core/models.go +++ b/core/models.go @@ -1,46 +1,54 @@ package core +import "github.com/shopspring/decimal" + type User struct { - Username string `gorm:"primaryKey" json:"username"` - Password string `json:"password"` - OwnedPlans []Plan `gorm:"foreignKey:Owner;references:Username" json:"-"` - MemberOf []Member `json:"-"` - Votes []Vote `gorm:"foreignKey:UsernameID" json:"-"` + Username string `json:"username"` + Password string `json:"password"` } type Member struct { - ID uint `gorm:"primaryKey;autoIncrement:true" json:"-"` - PlanID uint `json:"-"` - Plan Plan `json:"-"` - Type string `gorm:"check:type in ('user','non-user')" json:"type"` - Name string `gorm:"check:type=='member' OR name IS NOT NULL" json:"name"` - UserID string `json:"username"` - User User `gorm:"foreignKey:UserID" json:"-"` + Id uint `json:"-"` + PlanId uint `json:"-"` + Name string `json:"name,omitempty"` + UserId *string `json:"username,omitempty"` } // CREATE TABLE plans(id INTEGER PRIMARY KEY AUTOINCREMENT, name STRING, owner STRING, FOREIGN KEY(owner) REFERENCES users(username)) // CREATE TABLE plan_user_relations(username STRING, plan INTEGER, PRIMARY KEY(username, plan), FOREIGN KEY username REFERENCES user(username), FOREIGN KEY plan REFERENCES plans(id)) type Plan struct { - ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` - Name string `json:"name"` - Owner string `json:"owner"` - Description string `json:"description"` - JoinCode string `gorm:"not null" json:"join_code,omitempty"` - Members []Member `json:"-"` - Polls []Poll `gorm:"foreignKey:PlanID;references:ID" json:"-"` + Id uint `json:"id"` + Name string `json:"name"` + Owner string `json:"owner"` + Description string `json:"description"` + JoinCode string `json:"join_code,omitempty"` +} + +type Expense struct { + Id uint `json:"id"` + PlanId uint `json:"-"` + PayerId uint `json:"-"` + Amount decimal.Decimal `json:"amount"` +} + +type Debt struct { + Id uint + ExpenseId uint `json:"-"` + DebtorId uint `json:"-"` + Amount decimal.Decimal + Paid decimal.Decimal } // CREATE TABLE polls(id INTEGER PRIMARY KEY AUTOINCREMENT, plan INTEGER, name STRING, options JSON, FOREIGN KEY plan REFERENCES plans(id)) type Poll struct { - ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` - PlanID uint `json:"-"` + Id uint `json:"id"` + PlanId uint `json:"-"` Options string `json:"options"` - Votes []Vote `gorm:"foreignKey:PollID;references:ID" json:"-"` } // CREATE TABLE votes(id INTEGER, poll INTEGER, user STRING, value JSON, FOREIGN KEY poll REFERENCES polls(id), FOREIGN KEY user REFERENCES user(username)) type Vote struct { - PollID uint `gorm:"primaryKey" json:"-"` - UsernameID string `gorm:"primaryKey" json:"username_id"` - Value string `json:"value"` + PollId uint `json:"-"` + MemberId uint `json:"member_id"` + Value string `json:"value"` } diff --git a/core/plans.go b/core/plans.go index 32ad597..ac56514 100644 --- a/core/plans.go +++ b/core/plans.go @@ -2,13 +2,13 @@ package core import ( "crypto/rand" + "database/sql" "encoding/base64" "errors" - - "gorm.io/gorm" + "fmt" ) -func CreatePlan(db *gorm.DB, user *User, name string) (*Plan, error) { +func PlanCreate(db *sql.DB, user *User, name string) (*Plan, error) { join_code := make([]byte, 32) _, err := rand.Read(join_code) if err != nil { @@ -16,104 +16,101 @@ func CreatePlan(db *gorm.DB, user *User, name string) (*Plan, error) { } var plan Plan = Plan{ - Name: name, - Owner: user.Username, - Members: []Member{ - { - UserID: user.Username, - Type: "user", - }, - }, + Name: name, + Owner: user.Username, JoinCode: base64.URLEncoding.EncodeToString(join_code), } - result := db.Create(&plan) - if result.Error != nil { - return nil, result.Error + row := db.QueryRow( + "INSERT INTO plans(name, owner, description, join_jode) VALUES (?, ?, '', ?) RETURNING id", + plan.Name, + plan.Owner, + plan.JoinCode, + ) + if err := row.Scan(&plan.Id); err != nil { + return nil, err } return &plan, nil } -func GetPlan(orm *gorm.DB, id uint) (*Plan, error) { - var plan Plan = Plan{ - ID: id, +func PlanGet(db *sql.DB, id uint) (*Plan, error) { + plan := Plan{} + row := db.QueryRow("SELECT name,owner,description,join_code FROM plans WHERE id=?", id) + if err := row.Scan(&plan.Name, &plan.Owner, &plan.Description, &plan.JoinCode); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("Unexpected database error: %w", err) } - result := orm.Take(&plan) - - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } else if result.Error != nil { - return nil, result.Error - } - return &plan, nil } -func (p *Plan) GetAllUsers(orm *gorm.DB) ([]Member, error) { - var members []Member - err := orm.Model(p).Association("Members").Find(&members) +func (p *Plan) GetAllUsers(db *sql.DB) ([]Member, error) { + members := []Member{} + rows, err := db.Query("SELECT id,plan_id,name,user_id FROM members WHERE plan_id=?", p.Id) + if err != nil { + return members, err + } + defer rows.Close() + for rows.Next() { + var m Member + rows.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId) + members = append(members, m) + } return members, err } -func (p *Plan) IsMember(orm *gorm.DB, u *User) (bool, error) { - var member_count int64 - result := orm. - Table("members"). - Where("user_id=? AND plan_id=?", u.Username, p.ID). - Count(&member_count).Error - if result != nil { - return false, ErrNotMember +func (p *Plan) IsMember(db *sql.DB, u *User) (bool, error) { + var member_count int + row := db.QueryRow("SELECT count(1) FROM members WHERE plan_id=? AND user_id=?", p.Id, u.Username) + if err := row.Scan(&member_count); err != nil { + return false, nil } return member_count == 1, nil } -func (p *Plan) GetMember(orm *gorm.DB, u *User) (*Member, error) { +func (p *Plan) GetMember(db *sql.DB, u *User) (*Member, error) { var m Member - result := orm. - Table("members"). - Where("user_id=? AND plan_id=?", u.Username, p.ID). - Take(&m).Error - if result != nil { - return nil, ErrNotMember + row := db.QueryRow("SELECT id,plan_id,name,user_id FROM members WHERE plan_id=? AND user_id=?", p.Id, u.Username) + if err := row.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotMember + } + return nil, fmt.Errorf("Unable to check if user %s is member in plan %d: %w", u.Username, p.Id, err) } return &m, nil } -func (p *Plan) HasNonUser(orm *gorm.DB, name string) (bool, error) { - var member_count int64 - result := orm. - Table("members"). - Where("plan_id=? AND name=?", p.ID, name). - Count(&member_count).Error - if result != nil { - return false, result +func (p *Plan) HasNonUser(db *sql.DB, name string) (bool, error) { + var member_count int + row := db.QueryRow("SELECT count(1) FROM members WHERE plan_id=? AND name=?", p.Id, name) + if err := row.Scan(&member_count); err != nil { + return false, nil } return member_count == 1, nil } -func (p *Plan) AddMember(orm *gorm.DB, new_member *Member) error { +func (p *Plan) AddMember(db *sql.DB, new_member *Member) error { if new_member == nil { return errors.New("Member is nil") } - new_member.PlanID = p.ID - if new_member.Type == "non-user" { - if new_member.Name == "" { - return errors.New("name required for non user") - } - found, err := p.HasNonUser(orm, new_member.Name) + new_member.PlanId = p.Id + if new_member.Name != "" { + found, err := p.HasNonUser(db, new_member.Name) if err != nil { return nil } if found { return errors.New("Non user name taken") } - return orm.Create(&new_member).Error - } else if new_member.Type == "user" { - user, err := GetUser(orm, new_member.UserID) + _, err = db.Exec("INSERT INTO members(plan_id, name) VALUES (?,?)", new_member.PlanId, new_member.Name) + return err + } else if *new_member.UserId != "" { + user, err := UserGet(db, *new_member.UserId) if err != nil { return err } - found, err := p.IsMember(orm, &user) + found, err := p.IsMember(db, &user) if err != nil { return nil } @@ -121,8 +118,9 @@ func (p *Plan) AddMember(orm *gorm.DB, new_member *Member) error { return errors.New("User already is member") } - return orm.Create(&new_member).Error + _, err = db.Exec("INSERT INTO members(plan_id, user_id) VALUES (?,?)", new_member.PlanId, new_member.UserId) + return err } else { - return errors.New("Invalid type for user") + return errors.New("Member object requires one of Name or UserID to be filled") } } diff --git a/core/polls.go b/core/polls.go index 564d0b5..7401c30 100644 --- a/core/polls.go +++ b/core/polls.go @@ -1,43 +1,59 @@ package core import ( + "database/sql" "fmt" + "slices" "strings" - - "gorm.io/gorm" ) -func GetPoll(orm *gorm.DB, user User, id uint) (*Poll, error) { - var poll Poll = Poll{ - ID: id, +func (p *Poll) Create(db *sql.DB) error { + row := db.QueryRow( + "INSERT INTO polls(plan_id, options) VALUES(?,?) RETURNING id", + p.PlanId, + p.Options, + ) + if err := row.Scan(&p.Id); err != nil { + return fmt.Errorf("Unable to create new poll for plan %d: %w", p.PlanId, err) } + return nil +} - if result := orm.Take(&poll); result.Error != nil { - return nil, result.Error +func PollGet(db *sql.DB, user User, id uint) (*Poll, error) { + var poll Poll + row := db.QueryRow("SELECT id,plan_id,options FROM polls WHERE id=?", id) + if err := row.Scan(&poll.Id, &poll.PlanId, &poll.Options); err != nil { + return nil, err } - fmt.Printf("%+v\n", poll.PlanID) - return &poll, nil } -func (p *Poll) SetVote(orm *gorm.DB, user User, option string) error { - found := false +func PollsList(db *sql.DB, plan_id int) ([]Poll, error) { + var polls []Poll + rows, err := db.Query("SELECT id,plan_id,options FROM polls WHERE plan_id=?", plan_id) + if err != nil { + return polls, fmt.Errorf("Unable to query polls for plan %d: %w", plan_id, err) + } + defer rows.Close() + for rows.Next() { + var poll Poll + if err := rows.Scan(&poll.Id, &poll.PlanId, &poll.Options); err != nil { + return nil, fmt.Errorf("Unable to scan polls for plan %d: %w", plan_id, err) + } + polls = append(polls, poll) + } + + return polls, nil +} + +func (p *Poll) SetVote(db *sql.DB, member Member, option string) error { options := strings.Split(p.Options, ",") - for _, opt := range options { - if opt == option { - found = true - break - } - } - if !found { - return ErrInvalidOption + if !slices.Contains(options, option) { + return fmt.Errorf("%s is not a valid option (%s): %w", option, options, ErrInvalidOption) } + _, err := db.Exec("INSERT INTO votes(poll_id,member_id,value) VALUES (?,?,?)", p.Id, member.Id, option) - if res := orm.Create(Vote{PollID: p.ID, UsernameID: user.Username, Value: option}); res.Error != nil { - return res.Error - } - - return nil + return err } diff --git a/core/sqlexecutor.go b/core/sqlexecutor.go new file mode 100644 index 0000000..89bfdb2 --- /dev/null +++ b/core/sqlexecutor.go @@ -0,0 +1,9 @@ +package core + +import "database/sql" + +type SqlExecutor interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} diff --git a/core/users.go b/core/users.go index b863ec6..9f0956f 100644 --- a/core/users.go +++ b/core/users.go @@ -1,31 +1,49 @@ package core import ( + "database/sql" "errors" - - "gorm.io/gorm" + "fmt" ) -func (u *User) ListPlans(orm *gorm.DB) ([]Plan, error) { +func (u *User) ListPlans(db *sql.DB) ([]Plan, error) { var plans []Plan - err := orm.Debug().Table("plans p"). - Select("p.*"). - Joins("JOIN members m ON m.plan_id=p.id"). - Where("m.user_id=?", u.Username). - Find(&plans) - return plans, err.Error + rows, err := db.Query( + "SELECT p.id,p.name,p.owner,p.description,p.join_code FROM plans p "+ + "JOIN members m ON m.plan_id=p.id "+ + "WHERE m.user_id=?", + u.Username, + ) + if err != nil { + return plans, fmt.Errorf( + "Unable to query plans for user %s: %w", + u.Username, + err, + ) + } + defer rows.Close() + for rows.Next() { + p := Plan{} + err = rows.Scan(&p.Id, &p.Name, &p.Owner, &p.Description, &p.JoinCode) + if err != nil { + return plans, fmt.Errorf( + "Unable to scan plan for user %s: %w", + u.Username, + err, + ) + } + plans = append(plans, p) + } + return plans, nil } -func (u *User) GetPlan(db *gorm.DB, plan_id uint) (Plan, error) { - var plan Plan = Plan{ - ID: plan_id, - } - result := db.Take(&plan) +func (u *User) GetPlan(db *sql.DB, plan_id uint) (*Plan, error) { + plan, err := PlanGet(db, plan_id) - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return plan, ErrNotFound - } else if result.Error != nil { - return plan, result.Error + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } else if err != nil { + return nil, err } if plan.Owner == u.Username { @@ -33,18 +51,43 @@ func (u *User) GetPlan(db *gorm.DB, plan_id uint) (Plan, error) { } isMember, err := plan.IsMember(db, u) - if !isMember || err != nil { - return plan, err + return nil, err } return plan, nil } -func GetUser(orm *gorm.DB, username string) (User, error) { - user := User{Username: username} - if err := orm.Take(&user).Error; err != nil { - return user, err +func (u *User) GetMemberFromPlan(db *sql.DB, p Plan) (*Member, error) { + m := Member{} + row := db.QueryRow( + "SELECT id,plan_id,name,user_id FROM members WHERE user_id=? AND plan_id=?", + u.Username, + p.Id, + ) + if err := row.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId); err != nil { + return nil, fmt.Errorf( + "Unable to get member from user %s and plan %d: %w", + u.Username, + p.Id, + err, + ) } - return user, nil + return &m, nil +} + +func UserGet(db *sql.DB, username string) (User, error) { + var user User + row := db.QueryRow("SELECT username, password FROM users WHERE username=?", username) + err := row.Scan(&user.Username, &user.Password) + return user, err +} + +func (u *User) Create(db *sql.DB) error { + _, err := db.Exec( + "INSERT INTO users(username, password) VALUES (?,?)", + u.Username, + u.Password, + ) + return fmt.Errorf("Unable to create user %s: %w", u.Username, err) } diff --git a/core/votes.go b/core/votes.go new file mode 100644 index 0000000..81ddd0b --- /dev/null +++ b/core/votes.go @@ -0,0 +1,23 @@ +package core + +import ( + "database/sql" + "fmt" +) + +func VotesList(db *sql.DB, poll_id int) ([]Vote, error) { + votes := []Vote{} + rows, err := db.Query("SELECT poll_id,member_id,value FROM polls WHERE poll_id=?", poll_id) + if err != nil { + return votes, fmt.Errorf("Unable to get votes for poll %d: %w", poll_id, err) + } + defer rows.Close() + for rows.Next() { + v := Vote{} + err = rows.Scan(&v.PollId, &v.MemberId, &v.Value) + if err != nil { + return votes, fmt.Errorf("Unable to scan vote for poll %d: %w", poll_id, err) + } + } + return votes, nil +} diff --git a/db.go b/db.go index 1f49809..c90bf34 100644 --- a/db.go +++ b/db.go @@ -1,45 +1,33 @@ package main import ( - "gorm.io/driver/sqlite" - "gorm.io/gorm" + "database/sql" "log" "os" - . "planner/core" + + _ "github.com/mattn/go-sqlite3" ) -func bootstrapDatabase() *gorm.DB { +func bootstrapDatabase() *sql.DB { fi, err := os.Stat("./db.sqlite") - var db *gorm.DB + var db *sql.DB if err != nil { if os.IsNotExist(err) { - db, err := gorm.Open(sqlite.Open("./db.sqlite"), &gorm.Config{}) + db, err := sql.Open("sqlite3", "./db.sqlite") if err != nil { log.Fatal(err) return nil } - db.AutoMigrate(&User{}, &Member{}, &Plan{}, &Poll{}, &Vote{}) - - //var tables = [...]struct { - // key string - // query string - //}{ - // {key: "users", query: "CREATE TABLE users(username STRING PRIMARY KEY, password STRING)"}, - // {key: "plans", query: "CREATE TABLE plans(id INTEGER PRIMARY KEY AUTOINCREMENT, name STRING, owner STRING, FOREIGN KEY(owner) REFERENCES users(username))"}, - // {key: "plan_user_relations", query: "CREATE TABLE plan_user_relations(username STRING, plan INTEGER, PRIMARY KEY(username, plan), FOREIGN KEY username REFERENCES user(username), FOREIGN KEY plan REFERENCES plans(id))"}, - // {key: "polls", query: "CREATE TABLE polls(id INTEGER PRIMARY KEY AUTOINCREMENT, plan INTEGER, name STRING, options JSON, FOREIGN KEY plan REFERENCES plans(id))"}, - // {key: "votes", query: "CREATE TABLE votes(id INTEGER, poll INTEGER, user STRING, value JSON, FOREIGN KEY poll REFERENCES polls(id), FOREIGN KEY user REFERENCES user(username))"}, - //} - - //for _, table := range tables { - // _, err = db.Exec(table.query) - // if err != nil { - // log.Fatal("Failed to create " + table.key + " table") - // log.Fatal(err) - // return false - // } - //} + file, err := os.ReadFile("./base.sql") + if err != nil { + log.Fatal("Unable to open base.sql", err) + } + contents := string(file) + _, err = db.Exec(contents) + if err != nil { + log.Fatal("Unable to create base db", err) + } return db } else { log.Fatal(err) @@ -52,7 +40,7 @@ func bootstrapDatabase() *gorm.DB { return nil } - db, err = gorm.Open(sqlite.Open("./db.sqlite"), &gorm.Config{}) + db, err = sql.Open("sqlite3", "./db.sqlite") if err != nil { log.Fatal(err) diff --git a/go.mod b/go.mod index 581ef33..c36d2e7 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,8 @@ go 1.23.0 require ( github.com/gin-gonic/gin v1.10.0 - github.com/mattn/go-sqlite3 v1.14.23 - gorm.io/driver/sqlite v1.5.6 - gorm.io/gorm v1.25.12 + github.com/mattn/go-sqlite3 v1.14.27 + github.com/shopspring/decimal v1.4.0 ) require ( @@ -20,8 +19,6 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.22.0 // indirect github.com/goccy/go-json v0.10.3 // indirect - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index 45e3183..f23107e 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.27 h1:drZCnuvf37yPfs95E5jd9s3XhdVWLal+6BOK6qrv6IU= +github.com/mattn/go-sqlite3 v1.14.27/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -54,6 +56,8 @@ github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNH github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/planner.go b/planner.go index f30a018..907eed2 100644 --- a/planner.go +++ b/planner.go @@ -12,17 +12,16 @@ import ( func main() { fmt.Println("Opening database db.sqlite") - orm := bootstrapDatabase() + db := bootstrapDatabase() - if orm == nil { + if db == nil { log.Fatal("Failed to init database") return } - db, _ := orm.DB() defer db.Close() r := gin.Default() - apis.BindAPIs(r, orm) + apis.BindAPIs(r, db) r.Run() }