diff --git a/apis/base.go b/apis/base.go index a577c85..bbf34da 100644 --- a/apis/base.go +++ b/apis/base.go @@ -1,16 +1,16 @@ package apis import ( - "database/sql" "errors" "net/http" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) -var db *sql.DB +var db *gorm.DB -func BindAPIs(r *gin.Engine, cfg_db *sql.DB) error { +func BindAPIs(r *gin.Engine, cfg_db *gorm.DB) error { if cfg_db == nil { return errors.New("Database is null") } @@ -23,6 +23,5 @@ func BindAPIs(r *gin.Engine, cfg_db *sql.DB) error { bindPlanAPIs(r) bindPollAPIs(r) bindUserAPIs(r) - bindExpensesAPIs(r) return nil } diff --git a/apis/expenses.go b/apis/expenses.go deleted file mode 100644 index 68a7685..0000000 --- a/apis/expenses.go +++ /dev/null @@ -1,230 +0,0 @@ -package apis - -import ( - "net/http" - "planner/core" - "slices" - "strconv" - - "github.com/gin-gonic/gin" - "github.com/shopspring/decimal" -) - -type ExpenseDef struct { - Payer string `json:"payer"` - Type string `json:"type"` - Amount decimal.Decimal `json:"amount"` -} - -func createExpense(c *gin.Context) { - u := extractUser(db, c) - if u == nil { - c.Status(http.StatusUnauthorized) - return - } - - plan_id, err := strconv.ParseInt(c.Param("id"), 10, 32) - if err != nil || plan_id < 0 { - c.Status(http.StatusBadRequest) - return - } - plan, err := u.GetPlan(db, uint(plan_id)) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - - members, err := plan.GetAllUsers(db) - if err != nil { - c.Status(http.StatusInternalServerError) - return - } - var exp ExpenseDef - if err := c.ShouldBind(&exp); err != nil { - c.Status(http.StatusBadRequest) - return - } - - var compare func(m core.Member) bool - - if exp.Type == "user" { - compare = func(m core.Member) bool { - return *m.UserId == exp.Payer - } - } else if exp.Type == "non-user" { - compare = func(m core.Member) bool { - return m.Name == exp.Payer - } - } else { - c.String(http.StatusBadRequest, "Invalid member type") - return - } - idx := slices.IndexFunc(members, compare) - if idx != -1 { - expense, err := core.ExpenseCreate(db, *plan, members[idx], exp.Amount) - if err != nil { - c.Status(http.StatusInternalServerError) - return - } - c.JSON(http.StatusOK, expense) - return - } - c.String(http.StatusBadRequest, "Unable to found member") -} - -func listExpenses(c *gin.Context) { - u := extractUser(db, c) - if u == nil { - c.Status(http.StatusUnauthorized) - return - } - - plan_id, err := strconv.ParseInt(c.Param("id"), 10, 32) - if err != nil || plan_id < 0 { - c.Status(http.StatusBadRequest) - return - } - plan, err := u.GetPlan(db, uint(plan_id)) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - expenses, err := core.ExpensesList(db, *plan) - if err != nil { - c.Status(http.StatusInternalServerError) - return - } - c.JSON(http.StatusOK, expenses) -} - -func getExpense(c *gin.Context) { - u := extractUser(db, c) - if u == nil { - c.Status(http.StatusUnauthorized) - return - } - - expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) - if err != nil || expense_id < 0 { - c.Status(http.StatusBadRequest) - return - } - expense, err := core.ExpensesGet(db, uint(expense_id)) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - _, err = u.GetPlan(db, expense.PlanId) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - - c.JSON(http.StatusOK, expense) -} - -func deleteExpense(c *gin.Context) { - u := extractUser(db, c) - if u == nil { - c.Status(http.StatusUnauthorized) - return - } - - expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) - if err != nil || expense_id < 0 { - c.Status(http.StatusBadRequest) - return - } - expense, err := core.ExpensesGet(db, uint(expense_id)) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - _, err = u.GetPlan(db, expense.PlanId) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - - if err = expense.Delete(db); err != nil { - c.Status(http.StatusInternalServerError) - return - } - - c.Status(http.StatusOK) -} - -func getExpenseDebts(c *gin.Context) { - u := extractUser(db, c) - if u == nil { - c.Status(http.StatusUnauthorized) - return - } - - expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) - if err != nil || expense_id < 0 { - c.Status(http.StatusBadRequest) - return - } - expense, err := core.ExpensesGet(db, uint(expense_id)) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - _, err = u.GetPlan(db, expense.PlanId) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - - debts, err := expense.GetDebt(db) - if err != nil { - c.Status(http.StatusInternalServerError) - return - } - - c.JSON(http.StatusOK, debts) -} - -func setExpenseDebts(c *gin.Context) { - u := extractUser(db, c) - if u == nil { - c.Status(http.StatusUnauthorized) - return - } - - expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) - if err != nil || expense_id < 0 { - c.Status(http.StatusBadRequest) - return - } - expense, err := core.ExpensesGet(db, uint(expense_id)) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - _, err = u.GetPlan(db, expense.PlanId) - if err != nil { - c.Status(http.StatusBadRequest) - return - } - - var debtSpec []core.DebtSpec - if err := c.ShouldBind(&debtSpec); err != nil { - c.Status(http.StatusBadRequest) - return - } - - //expense.SetDebt(db, debtSpec) - - c.JSON(http.StatusOK, debtSpec) -} - -func bindExpensesAPIs(r *gin.Engine) { - r.POST("/plans/:id/expenses", createExpense) - r.GET("/plans/:id/expenses", listExpenses) - r.GET("/expenses/:id", getExpense) - r.DELETE("/expenses/:id", deleteExpense) - r.GET("/expenses/:id/debts", getExpenseDebts) - r.POST("/expenses/:id/debts", setExpenseDebts) -} diff --git a/apis/extract.go b/apis/extract.go index 0a9cd86..7ed4f7d 100644 --- a/apis/extract.go +++ b/apis/extract.go @@ -1,22 +1,24 @@ package apis import ( - "database/sql" + "github.com/gin-gonic/gin" + "gorm.io/gorm" "net/http" "planner/core" - - "github.com/gin-gonic/gin" ) -func extractUser(orm *sql.DB, c *gin.Context) *core.User { +func extractUser(orm *gorm.DB, c *gin.Context) *core.User { username, _, ok := c.Request.BasicAuth() if !ok { c.Status(http.StatusUnauthorized) return nil } + u := core.User{ + Username: username, + } - u, err := core.UserGet(db, username) - if err != nil { + result := orm.Take(&u) + if result.Error != nil { c.String(http.StatusNotFound, "Unable to find user "+username) return nil } diff --git a/apis/plans.go b/apis/plans.go index 7430277..2b40112 100644 --- a/apis/plans.go +++ b/apis/plans.go @@ -8,7 +8,6 @@ import ( "github.com/gin-gonic/gin" - "planner/core" . "planner/core" ) @@ -23,7 +22,7 @@ func createPlan(c *gin.Context) { } c.Bind(&plan_req) - _, err := PlanCreate(db, u, plan_req.Name) + _, err := CreatePlan(db, u, plan_req.Name) if err != nil { c.JSON(http.StatusInternalServerError, err) @@ -120,7 +119,7 @@ func addPlanMember(c *gin.Context) { return } - err = plan.AddMember(db, &Member{Name: new_member.Name}) + err = plan.AddMember(db, &Member{Name: new_member.Name, Type: "non-user"}) if err == nil { c.JSON(http.StatusOK, new_member) @@ -145,7 +144,7 @@ func joinPlan(c *gin.Context) { c.Status(http.StatusBadRequest) return } - plan, err := PlanGet(db, uint(plan_id)) + plan, err := GetPlan(db, uint(plan_id)) if err != nil || plan == nil { c.Status(http.StatusInternalServerError) return @@ -172,7 +171,7 @@ func joinPlan(c *gin.Context) { c.String(http.StatusConflict, "User already a member") return } - plan.AddMember(db, &Member{UserId: &user.Username}) + plan.AddMember(db, &Member{Type: "user", UserID: user.Username}) c.Status(http.StatusOK) } @@ -213,10 +212,10 @@ func createPlanPoll(c *gin.Context) { } poll := Poll{ - PlanId: plan.Id, + PlanID: plan.ID, Options: poll_opts.Options, } - poll.Create(db) + db.Create(&poll) c.JSON(http.StatusCreated, poll) } @@ -237,11 +236,8 @@ func listPlanPolls(c *gin.Context) { return } - polls, err := core.PollsList(db, int(params.Id)) - if err != nil { - c.String(http.StatusInternalServerError, err.Error()) - return - } + var polls []Poll + db.Where("plan_id = ?", params.Id).Find(&polls) c.JSON(http.StatusOK, polls) } diff --git a/apis/polls.go b/apis/polls.go index fa91666..93d43d6 100644 --- a/apis/polls.go +++ b/apis/polls.go @@ -25,7 +25,7 @@ func getPoll(c *gin.Context) { } fmt.Println(params) - poll, _ := core.PollGet(db, *user, params.PollId) + poll, _ := core.GetPoll(db, *user, params.PollId) c.JSON(http.StatusOK, poll) } @@ -45,12 +45,9 @@ func getPollVotes(c *gin.Context) { return } - votes, err := core.VotesList(db, int(params.PollId)) - if err != nil { - c.String(http.StatusInternalServerError, err.Error()) - return - } - c.JSON(http.StatusOK, votes) + var votes []core.Vote + db.Where("poll_id = ?", params.PollId).Find(&votes) + c.JSON(http.StatusOK, &votes) } func pollVote(c *gin.Context) { @@ -69,7 +66,7 @@ func pollVote(c *gin.Context) { return } - poll, err := core.PollGet(db, *user, path_params.PollId) + poll, err := core.GetPoll(db, *user, path_params.PollId) if err != nil { c.String(http.StatusInternalServerError, err.Error()) return @@ -80,19 +77,7 @@ func pollVote(c *gin.Context) { } c.Bind(&vote_params) - plan, err := core.PlanGet(db, poll.PlanId) - if err != nil { - c.String(http.StatusInternalServerError, "Unable to find plan: "+err.Error()) - return - } - - member, err := user.GetMemberFromPlan(db, *plan) - if err != nil { - c.String(http.StatusInternalServerError, "Unable to find member: "+err.Error()) - return - } - - if err := poll.SetVote(db, *member, vote_params.Vote); err != nil { + if err := poll.SetVote(db, *user, vote_params.Vote); err != nil { c.String(http.StatusBadRequest, err.Error()) } diff --git a/apis/users.go b/apis/users.go index 3fd802c..16ad82a 100644 --- a/apis/users.go +++ b/apis/users.go @@ -6,7 +6,7 @@ import ( "github.com/gin-gonic/gin" - "planner/core" + . "planner/core" ) func getUserByName(c *gin.Context) { @@ -14,23 +14,19 @@ func getUserByName(c *gin.Context) { Name string `fdb:"name"` } if c.ShouldBind(&q) == nil { - user, err := core.UserGet(db, q.Name) - if err != nil { - c.String(http.StatusInternalServerError, "Unable to get user: "+err.Error()) - return + user := User{ + Username: q.Name, } + db.Take(&user) + fmt.Println(user) c.JSON(http.StatusOK, user) } } func createUser(c *gin.Context) { - var u core.User + var u User if c.ShouldBind(&u) == nil { - err := u.Create(db) - if err != nil { - c.String(http.StatusInternalServerError, "Unable to create user: "+err.Error()) - return - } + db.Create(&u) c.Status(http.StatusCreated) } else { fmt.Print("Could not bind model") @@ -47,13 +43,12 @@ func login(c *gin.Context) { if q.Username == "" { c.String(http.StatusBadRequest, "Login data is null") } else { - user, err := core.UserGet(db, q.Username) - if err != nil { - c.String(http.StatusInternalServerError, "Unable to get user: "+err.Error()) - return + user := User{ + Username: q.Username, } + db.Take(&user) if user.Password == q.Password { - c.JSON(http.StatusOK, gin.H{"username": user.Username}) + c.JSON(http.StatusOK, map[string]string{"username": user.Username}) } else { c.Status(http.StatusForbidden) } diff --git a/base.sql b/base.sql deleted file mode 100644 index ec3cfb8..0000000 --- a/base.sql +++ /dev/null @@ -1,47 +0,0 @@ -CREATE TABLE users ( - username TEXT PRIMARY KEY, - password TEXT NOT NULL -); - -CREATE TABLE plans ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - owner TEXT REFERENCES users(username) NOT NULL, - description TEXT DEFAULT '', - join_code TEXT NOT NULL -); - -CREATE TABLE members ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - plan_id INTEGER REFERENCES plans(id) NOT NULL, - name TEXT NOT NULL, - user_id TEXT REFERENCES users(username) -); - -CREATE TABLE polls ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - plan_id INTEGER REFERENCES plans(id) NOT NULL, - options TEXT DEFAULT '' -); - -CREATE TABLE votes ( - poll_id INTEGER REFERENCES polls(id), - member_id INTEGER REFERENCES members(id), - value text, - PRIMARY KEY (poll_id,member_id) -); - -CREATE TABLE expenses ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - plan_id INTEGER REFERENCES plans(id) NOT NULL, - payer_id INTEGER REFERENCES members(id) NOT NULL, - amount DECIMAL NOT NULL -); - -CREATE TABLE debts ( - expense_id INTEGER REFERENCES expenses(id) NOT NULL, - debtor_id INTEGER REFERENCES members(id) NOT NULL, - amount DECIMAL NOT NULL, - paid DECIMAL, - PRIMARY KEY (expense_id,debtor_id) -); diff --git a/core/debts.go b/core/debts.go deleted file mode 100644 index 5fcbc8b..0000000 --- a/core/debts.go +++ /dev/null @@ -1,12 +0,0 @@ -package core - -func (d *Debt) Create(db SqlExecutor) error { - _, err := db.Exec( - "INSERT INTO debts(expense_id, debtor_id, amount, paid) VALUES (?,?,?,?)", - d.ExpenseId, - d.DebtorId, - d.Amount, - d.Paid, - ) - return err -} diff --git a/core/expenses.go b/core/expenses.go deleted file mode 100644 index fbdae05..0000000 --- a/core/expenses.go +++ /dev/null @@ -1,196 +0,0 @@ -package core - -import ( - "database/sql" - "encoding/json" - "errors" - "fmt" - - "github.com/shopspring/decimal" -) - -type DebtType int - -const ( - ProportionalDebt DebtType = iota - AbsoluteDebt = iota -) - -func (dt DebtType) String() string { - return [...]string{"proportional", "absolute"}[dt] -} - -func (dt DebtType) MarshalJSON() ([]byte, error) { - return json.Marshal(dt.String()) -} - -func (dt *DebtType) UnmarshalJSON(data []byte) error { - var s string - if err := json.Unmarshal(data, &s); err != nil { - return nil - } - if s == "proportional" { - *dt = ProportionalDebt - } else if s == "absolute" { - *dt = AbsoluteDebt - } else { - return errors.New("Invalid value for debt type") - } - return nil -} - -type DebtSpec struct { - Member Member `json:"member"` - Amount decimal.Decimal `json:"amount"` - DebtType DebtType `json:"debt_type"` -} - -func ExpenseCreate(db *sql.DB, plan Plan, payer Member, amount decimal.Decimal) (*Expense, error) { - expense := new(Expense) - expense.PlanId = plan.Id - expense.PayerId = payer.Id - expense.Amount = amount - row := db.QueryRow( - "INSERT INTO expenses(plan_id,payer_id,amount) VALUES (?,?,?) RETURNING id", - expense.PlanId, - expense.PayerId, - expense.Amount, - ) - if err := row.Scan(&expense.Id); err != nil { - return nil, err - } - return expense, nil -} - -func (e *Expense) GetDebt(db *sql.DB) ([]Debt, error) { - var debts []Debt - rows, err := db.Query("SELECT expense_id,debtor_id,amount,paid FROM debts WHERE expense_id=?") - if err != nil { - return debts, err - } - defer rows.Close() - for rows.Next() { - var d Debt - err = rows.Scan(&d.ExpenseId, &d.DebtorId, &d.Amount, &d.Paid) - if err != nil { - return []Debt{}, err - } - debts = append(debts, d) - } - return debts, nil -} - -func (e *Expense) SetDebt(db *sql.DB, debtSpec []DebtSpec) error { - abs_paid := decimal.Decimal{} - debts := make([]Debt, 0) - var prop_payers int64 = 0 - for _, debt := range debtSpec { - if debt.DebtType == ProportionalDebt { - prop_payers += 1 - } else { - debts = append(debts, Debt{ - Amount: debt.Amount, - Paid: decimal.Decimal{}, - ExpenseId: e.Id, - DebtorId: debt.Member.Id, - }) - abs_paid.Add(debt.Amount) - } - } - if abs_paid.Cmp(e.Amount) > 0 { - return errors.New("Absolute pay amount is larger than debt") - } - - prop_debt := e.Amount.Sub(abs_paid).DivRound(decimal.NewFromInt(prop_payers), 2) - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, debt := range debts { - if err := debt.Create(tx); err != nil { - return err - } - } - - for _, debt := range debtSpec { - if debt.DebtType != ProportionalDebt { - continue - } - debtObj := Debt{ - Amount: prop_debt, - Paid: decimal.Decimal{}, - ExpenseId: e.Id, - DebtorId: debt.Member.Id, - } - if err := debtObj.Create(tx); err != nil { - return err - } - } - return tx.Commit() -} - -func (e *Expense) Delete(db *sql.DB) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - _, err = tx.Exec("DELETE FROM debts WHERE expense_id=?", e.Id) - if err != nil { - return fmt.Errorf("Unable to delete debts for expense %d: %w", e.Id, err) - } - _, err = tx.Exec("DELETE FROM expenses WHERE id=?", e.Id) - if err != nil { - return fmt.Errorf("Unable to delete expense %d: %w", e.Id, err) - } - err = tx.Commit() - if err != nil { - return fmt.Errorf( - "Unable to commit transaction when deleting expense %d: %w", - e.Id, - err, - ) - } - return nil -} - -func ExpensesList(db *sql.DB, plan Plan) ([]Expense, error) { - var expenses []Expense - rows, err := db.Query( - "SELECT id,plan_id,payer_id,amount FROM expenses WHERE plan_id=$1", - plan.Id, - ) - if err != nil { - return expenses, fmt.Errorf( - "Unable to query for expenses in plan %d: %w", - plan.Id, - err, - ) - } - defer rows.Close() - for rows.Next() { - e := Expense{} - err = rows.Scan(&e.Id, e.PlanId, e.PayerId, e.Amount) - if err != nil { - return expenses, fmt.Errorf( - "Unable to scan expense during list: %w", - err, - ) - } - expenses = append(expenses, e) - } - return expenses, nil -} - -func ExpensesGet(db *sql.DB, expense_id uint) (*Expense, error) { - e := Expense{} - row := db.QueryRow( - "SELECT id,plan_id,payer_id,amount FROM expenses WHERE id=?", - expense_id, - ) - if err := row.Scan(&e.Id, &e.PlanId, &e.PayerId, &e.Amount); err != nil { - return nil, fmt.Errorf("Unable to get expense %d: %w", expense_id, err) - } - return &e, nil -} diff --git a/core/models.go b/core/models.go index f5fe490..53268c0 100644 --- a/core/models.go +++ b/core/models.go @@ -1,54 +1,46 @@ package core -import "github.com/shopspring/decimal" - type User struct { - Username string `json:"username"` - Password string `json:"password"` + Username string `gorm:"primaryKey" json:"username"` + Password string `json:"password"` + OwnedPlans []Plan `gorm:"foreignKey:Owner;references:Username" json:"-"` + MemberOf []Member `json:"-"` + Votes []Vote `gorm:"foreignKey:UsernameID" json:"-"` } type Member struct { - Id uint `json:"-"` - PlanId uint `json:"-"` - Name string `json:"name,omitempty"` - UserId *string `json:"username,omitempty"` + ID uint `gorm:"primaryKey;autoIncrement:true" json:"-"` + PlanID uint `json:"-"` + Plan Plan `json:"-"` + Type string `gorm:"check:type in ('user','non-user')" json:"type"` + Name string `gorm:"check:type=='member' OR name IS NOT NULL" json:"name"` + UserID string `json:"username"` + User User `gorm:"foreignKey:UserID" json:"-"` } // CREATE TABLE plans(id INTEGER PRIMARY KEY AUTOINCREMENT, name STRING, owner STRING, FOREIGN KEY(owner) REFERENCES users(username)) // CREATE TABLE plan_user_relations(username STRING, plan INTEGER, PRIMARY KEY(username, plan), FOREIGN KEY username REFERENCES user(username), FOREIGN KEY plan REFERENCES plans(id)) type Plan struct { - Id uint `json:"id"` - Name string `json:"name"` - Owner string `json:"owner"` - Description string `json:"description"` - JoinCode string `json:"join_code,omitempty"` -} - -type Expense struct { - Id uint `json:"id"` - PlanId uint `json:"-"` - PayerId uint `json:"-"` - Amount decimal.Decimal `json:"amount"` -} - -type Debt struct { - Id uint - ExpenseId uint `json:"-"` - DebtorId uint `json:"-"` - Amount decimal.Decimal - Paid decimal.Decimal + ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` + Name string `json:"name"` + Owner string `json:"owner"` + Description string `json:"description"` + JoinCode string `gorm:"not null" json:"join_code,omitempty"` + Members []Member `json:"-"` + Polls []Poll `gorm:"foreignKey:PlanID;references:ID" json:"-"` } // CREATE TABLE polls(id INTEGER PRIMARY KEY AUTOINCREMENT, plan INTEGER, name STRING, options JSON, FOREIGN KEY plan REFERENCES plans(id)) type Poll struct { - Id uint `json:"id"` - PlanId uint `json:"-"` + ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` + PlanID uint `json:"-"` Options string `json:"options"` + Votes []Vote `gorm:"foreignKey:PollID;references:ID" json:"-"` } // CREATE TABLE votes(id INTEGER, poll INTEGER, user STRING, value JSON, FOREIGN KEY poll REFERENCES polls(id), FOREIGN KEY user REFERENCES user(username)) type Vote struct { - PollId uint `json:"-"` - MemberId uint `json:"member_id"` - Value string `json:"value"` + PollID uint `gorm:"primaryKey" json:"-"` + UsernameID string `gorm:"primaryKey" json:"username_id"` + Value string `json:"value"` } diff --git a/core/plans.go b/core/plans.go index ac56514..32ad597 100644 --- a/core/plans.go +++ b/core/plans.go @@ -2,13 +2,13 @@ package core import ( "crypto/rand" - "database/sql" "encoding/base64" "errors" - "fmt" + + "gorm.io/gorm" ) -func PlanCreate(db *sql.DB, user *User, name string) (*Plan, error) { +func CreatePlan(db *gorm.DB, user *User, name string) (*Plan, error) { join_code := make([]byte, 32) _, err := rand.Read(join_code) if err != nil { @@ -16,101 +16,104 @@ func PlanCreate(db *sql.DB, user *User, name string) (*Plan, error) { } var plan Plan = Plan{ - Name: name, - Owner: user.Username, + Name: name, + Owner: user.Username, + Members: []Member{ + { + UserID: user.Username, + Type: "user", + }, + }, JoinCode: base64.URLEncoding.EncodeToString(join_code), } - row := db.QueryRow( - "INSERT INTO plans(name, owner, description, join_jode) VALUES (?, ?, '', ?) RETURNING id", - plan.Name, - plan.Owner, - plan.JoinCode, - ) - if err := row.Scan(&plan.Id); err != nil { - return nil, err + result := db.Create(&plan) + if result.Error != nil { + return nil, result.Error } return &plan, nil } -func PlanGet(db *sql.DB, id uint) (*Plan, error) { - plan := Plan{} - row := db.QueryRow("SELECT name,owner,description,join_code FROM plans WHERE id=?", id) - if err := row.Scan(&plan.Name, &plan.Owner, &plan.Description, &plan.JoinCode); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNotFound - } - return nil, fmt.Errorf("Unexpected database error: %w", err) +func GetPlan(orm *gorm.DB, id uint) (*Plan, error) { + var plan Plan = Plan{ + ID: id, } + result := orm.Take(&plan) + + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } else if result.Error != nil { + return nil, result.Error + } + return &plan, nil } -func (p *Plan) GetAllUsers(db *sql.DB) ([]Member, error) { - members := []Member{} - rows, err := db.Query("SELECT id,plan_id,name,user_id FROM members WHERE plan_id=?", p.Id) - if err != nil { - return members, err - } - defer rows.Close() - for rows.Next() { - var m Member - rows.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId) - members = append(members, m) - } +func (p *Plan) GetAllUsers(orm *gorm.DB) ([]Member, error) { + var members []Member + err := orm.Model(p).Association("Members").Find(&members) return members, err } -func (p *Plan) IsMember(db *sql.DB, u *User) (bool, error) { - var member_count int - row := db.QueryRow("SELECT count(1) FROM members WHERE plan_id=? AND user_id=?", p.Id, u.Username) - if err := row.Scan(&member_count); err != nil { - return false, nil +func (p *Plan) IsMember(orm *gorm.DB, u *User) (bool, error) { + var member_count int64 + result := orm. + Table("members"). + Where("user_id=? AND plan_id=?", u.Username, p.ID). + Count(&member_count).Error + if result != nil { + return false, ErrNotMember } return member_count == 1, nil } -func (p *Plan) GetMember(db *sql.DB, u *User) (*Member, error) { +func (p *Plan) GetMember(orm *gorm.DB, u *User) (*Member, error) { var m Member - row := db.QueryRow("SELECT id,plan_id,name,user_id FROM members WHERE plan_id=? AND user_id=?", p.Id, u.Username) - if err := row.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNotMember - } - return nil, fmt.Errorf("Unable to check if user %s is member in plan %d: %w", u.Username, p.Id, err) + result := orm. + Table("members"). + Where("user_id=? AND plan_id=?", u.Username, p.ID). + Take(&m).Error + if result != nil { + return nil, ErrNotMember } return &m, nil } -func (p *Plan) HasNonUser(db *sql.DB, name string) (bool, error) { - var member_count int - row := db.QueryRow("SELECT count(1) FROM members WHERE plan_id=? AND name=?", p.Id, name) - if err := row.Scan(&member_count); err != nil { - return false, nil +func (p *Plan) HasNonUser(orm *gorm.DB, name string) (bool, error) { + var member_count int64 + result := orm. + Table("members"). + Where("plan_id=? AND name=?", p.ID, name). + Count(&member_count).Error + if result != nil { + return false, result } return member_count == 1, nil } -func (p *Plan) AddMember(db *sql.DB, new_member *Member) error { +func (p *Plan) AddMember(orm *gorm.DB, new_member *Member) error { if new_member == nil { return errors.New("Member is nil") } - new_member.PlanId = p.Id - if new_member.Name != "" { - found, err := p.HasNonUser(db, new_member.Name) + new_member.PlanID = p.ID + if new_member.Type == "non-user" { + if new_member.Name == "" { + return errors.New("name required for non user") + } + found, err := p.HasNonUser(orm, new_member.Name) if err != nil { return nil } if found { return errors.New("Non user name taken") } - _, err = db.Exec("INSERT INTO members(plan_id, name) VALUES (?,?)", new_member.PlanId, new_member.Name) - return err - } else if *new_member.UserId != "" { - user, err := UserGet(db, *new_member.UserId) + return orm.Create(&new_member).Error + } else if new_member.Type == "user" { + user, err := GetUser(orm, new_member.UserID) if err != nil { return err } - found, err := p.IsMember(db, &user) + found, err := p.IsMember(orm, &user) if err != nil { return nil } @@ -118,9 +121,8 @@ func (p *Plan) AddMember(db *sql.DB, new_member *Member) error { return errors.New("User already is member") } - _, err = db.Exec("INSERT INTO members(plan_id, user_id) VALUES (?,?)", new_member.PlanId, new_member.UserId) - return err + return orm.Create(&new_member).Error } else { - return errors.New("Member object requires one of Name or UserID to be filled") + return errors.New("Invalid type for user") } } diff --git a/core/polls.go b/core/polls.go index 7401c30..564d0b5 100644 --- a/core/polls.go +++ b/core/polls.go @@ -1,59 +1,43 @@ package core import ( - "database/sql" "fmt" - "slices" "strings" + + "gorm.io/gorm" ) -func (p *Poll) Create(db *sql.DB) error { - row := db.QueryRow( - "INSERT INTO polls(plan_id, options) VALUES(?,?) RETURNING id", - p.PlanId, - p.Options, - ) - if err := row.Scan(&p.Id); err != nil { - return fmt.Errorf("Unable to create new poll for plan %d: %w", p.PlanId, err) +func GetPoll(orm *gorm.DB, user User, id uint) (*Poll, error) { + var poll Poll = Poll{ + ID: id, } - return nil -} -func PollGet(db *sql.DB, user User, id uint) (*Poll, error) { - var poll Poll - row := db.QueryRow("SELECT id,plan_id,options FROM polls WHERE id=?", id) - if err := row.Scan(&poll.Id, &poll.PlanId, &poll.Options); err != nil { - return nil, err + if result := orm.Take(&poll); result.Error != nil { + return nil, result.Error } + fmt.Printf("%+v\n", poll.PlanID) + return &poll, nil } -func PollsList(db *sql.DB, plan_id int) ([]Poll, error) { - var polls []Poll - rows, err := db.Query("SELECT id,plan_id,options FROM polls WHERE plan_id=?", plan_id) - if err != nil { - return polls, fmt.Errorf("Unable to query polls for plan %d: %w", plan_id, err) - } - defer rows.Close() - for rows.Next() { - var poll Poll - if err := rows.Scan(&poll.Id, &poll.PlanId, &poll.Options); err != nil { - return nil, fmt.Errorf("Unable to scan polls for plan %d: %w", plan_id, err) - } - polls = append(polls, poll) - } - - return polls, nil -} - -func (p *Poll) SetVote(db *sql.DB, member Member, option string) error { +func (p *Poll) SetVote(orm *gorm.DB, user User, option string) error { + found := false options := strings.Split(p.Options, ",") - if !slices.Contains(options, option) { - return fmt.Errorf("%s is not a valid option (%s): %w", option, options, ErrInvalidOption) + for _, opt := range options { + if opt == option { + found = true + break + } + } + if !found { + return ErrInvalidOption } - _, err := db.Exec("INSERT INTO votes(poll_id,member_id,value) VALUES (?,?,?)", p.Id, member.Id, option) - return err + if res := orm.Create(Vote{PollID: p.ID, UsernameID: user.Username, Value: option}); res.Error != nil { + return res.Error + } + + return nil } diff --git a/core/sqlexecutor.go b/core/sqlexecutor.go deleted file mode 100644 index 89bfdb2..0000000 --- a/core/sqlexecutor.go +++ /dev/null @@ -1,9 +0,0 @@ -package core - -import "database/sql" - -type SqlExecutor interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} diff --git a/core/users.go b/core/users.go index 9f0956f..b863ec6 100644 --- a/core/users.go +++ b/core/users.go @@ -1,49 +1,31 @@ package core import ( - "database/sql" "errors" - "fmt" + + "gorm.io/gorm" ) -func (u *User) ListPlans(db *sql.DB) ([]Plan, error) { +func (u *User) ListPlans(orm *gorm.DB) ([]Plan, error) { var plans []Plan - rows, err := db.Query( - "SELECT p.id,p.name,p.owner,p.description,p.join_code FROM plans p "+ - "JOIN members m ON m.plan_id=p.id "+ - "WHERE m.user_id=?", - u.Username, - ) - if err != nil { - return plans, fmt.Errorf( - "Unable to query plans for user %s: %w", - u.Username, - err, - ) - } - defer rows.Close() - for rows.Next() { - p := Plan{} - err = rows.Scan(&p.Id, &p.Name, &p.Owner, &p.Description, &p.JoinCode) - if err != nil { - return plans, fmt.Errorf( - "Unable to scan plan for user %s: %w", - u.Username, - err, - ) - } - plans = append(plans, p) - } - return plans, nil + err := orm.Debug().Table("plans p"). + Select("p.*"). + Joins("JOIN members m ON m.plan_id=p.id"). + Where("m.user_id=?", u.Username). + Find(&plans) + return plans, err.Error } -func (u *User) GetPlan(db *sql.DB, plan_id uint) (*Plan, error) { - plan, err := PlanGet(db, plan_id) +func (u *User) GetPlan(db *gorm.DB, plan_id uint) (Plan, error) { + var plan Plan = Plan{ + ID: plan_id, + } + result := db.Take(&plan) - if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNotFound - } else if err != nil { - return nil, err + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return plan, ErrNotFound + } else if result.Error != nil { + return plan, result.Error } if plan.Owner == u.Username { @@ -51,43 +33,18 @@ func (u *User) GetPlan(db *sql.DB, plan_id uint) (*Plan, error) { } isMember, err := plan.IsMember(db, u) + if !isMember || err != nil { - return nil, err + return plan, err } return plan, nil } -func (u *User) GetMemberFromPlan(db *sql.DB, p Plan) (*Member, error) { - m := Member{} - row := db.QueryRow( - "SELECT id,plan_id,name,user_id FROM members WHERE user_id=? AND plan_id=?", - u.Username, - p.Id, - ) - if err := row.Scan(&m.Id, &m.PlanId, &m.Name, &m.UserId); err != nil { - return nil, fmt.Errorf( - "Unable to get member from user %s and plan %d: %w", - u.Username, - p.Id, - err, - ) +func GetUser(orm *gorm.DB, username string) (User, error) { + user := User{Username: username} + if err := orm.Take(&user).Error; err != nil { + return user, err } - return &m, nil -} - -func UserGet(db *sql.DB, username string) (User, error) { - var user User - row := db.QueryRow("SELECT username, password FROM users WHERE username=?", username) - err := row.Scan(&user.Username, &user.Password) - return user, err -} - -func (u *User) Create(db *sql.DB) error { - _, err := db.Exec( - "INSERT INTO users(username, password) VALUES (?,?)", - u.Username, - u.Password, - ) - return fmt.Errorf("Unable to create user %s: %w", u.Username, err) + return user, nil } diff --git a/core/votes.go b/core/votes.go deleted file mode 100644 index 81ddd0b..0000000 --- a/core/votes.go +++ /dev/null @@ -1,23 +0,0 @@ -package core - -import ( - "database/sql" - "fmt" -) - -func VotesList(db *sql.DB, poll_id int) ([]Vote, error) { - votes := []Vote{} - rows, err := db.Query("SELECT poll_id,member_id,value FROM polls WHERE poll_id=?", poll_id) - if err != nil { - return votes, fmt.Errorf("Unable to get votes for poll %d: %w", poll_id, err) - } - defer rows.Close() - for rows.Next() { - v := Vote{} - err = rows.Scan(&v.PollId, &v.MemberId, &v.Value) - if err != nil { - return votes, fmt.Errorf("Unable to scan vote for poll %d: %w", poll_id, err) - } - } - return votes, nil -} diff --git a/db.go b/db.go index c90bf34..1f49809 100644 --- a/db.go +++ b/db.go @@ -1,33 +1,45 @@ package main import ( - "database/sql" + "gorm.io/driver/sqlite" + "gorm.io/gorm" "log" "os" - - _ "github.com/mattn/go-sqlite3" + . "planner/core" ) -func bootstrapDatabase() *sql.DB { +func bootstrapDatabase() *gorm.DB { fi, err := os.Stat("./db.sqlite") - var db *sql.DB + var db *gorm.DB if err != nil { if os.IsNotExist(err) { - db, err := sql.Open("sqlite3", "./db.sqlite") + db, err := gorm.Open(sqlite.Open("./db.sqlite"), &gorm.Config{}) if err != nil { log.Fatal(err) return nil } - file, err := os.ReadFile("./base.sql") - if err != nil { - log.Fatal("Unable to open base.sql", err) - } - contents := string(file) - _, err = db.Exec(contents) - if err != nil { - log.Fatal("Unable to create base db", err) - } + db.AutoMigrate(&User{}, &Member{}, &Plan{}, &Poll{}, &Vote{}) + + //var tables = [...]struct { + // key string + // query string + //}{ + // {key: "users", query: "CREATE TABLE users(username STRING PRIMARY KEY, password STRING)"}, + // {key: "plans", query: "CREATE TABLE plans(id INTEGER PRIMARY KEY AUTOINCREMENT, name STRING, owner STRING, FOREIGN KEY(owner) REFERENCES users(username))"}, + // {key: "plan_user_relations", query: "CREATE TABLE plan_user_relations(username STRING, plan INTEGER, PRIMARY KEY(username, plan), FOREIGN KEY username REFERENCES user(username), FOREIGN KEY plan REFERENCES plans(id))"}, + // {key: "polls", query: "CREATE TABLE polls(id INTEGER PRIMARY KEY AUTOINCREMENT, plan INTEGER, name STRING, options JSON, FOREIGN KEY plan REFERENCES plans(id))"}, + // {key: "votes", query: "CREATE TABLE votes(id INTEGER, poll INTEGER, user STRING, value JSON, FOREIGN KEY poll REFERENCES polls(id), FOREIGN KEY user REFERENCES user(username))"}, + //} + + //for _, table := range tables { + // _, err = db.Exec(table.query) + // if err != nil { + // log.Fatal("Failed to create " + table.key + " table") + // log.Fatal(err) + // return false + // } + //} return db } else { log.Fatal(err) @@ -40,7 +52,7 @@ func bootstrapDatabase() *sql.DB { return nil } - db, err = sql.Open("sqlite3", "./db.sqlite") + db, err = gorm.Open(sqlite.Open("./db.sqlite"), &gorm.Config{}) if err != nil { log.Fatal(err) diff --git a/go.mod b/go.mod index c36d2e7..581ef33 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,9 @@ go 1.23.0 require ( github.com/gin-gonic/gin v1.10.0 - github.com/mattn/go-sqlite3 v1.14.27 - github.com/shopspring/decimal v1.4.0 + github.com/mattn/go-sqlite3 v1.14.23 + gorm.io/driver/sqlite v1.5.6 + gorm.io/gorm v1.25.12 ) require ( @@ -19,6 +20,8 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.22.0 // indirect github.com/goccy/go-json v0.10.3 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index f23107e..45e3183 100644 --- a/go.sum +++ b/go.sum @@ -45,8 +45,6 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/mattn/go-sqlite3 v1.14.27 h1:drZCnuvf37yPfs95E5jd9s3XhdVWLal+6BOK6qrv6IU= -github.com/mattn/go-sqlite3 v1.14.27/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -56,8 +54,6 @@ github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNH github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= -github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/planner.go b/planner.go index 907eed2..f30a018 100644 --- a/planner.go +++ b/planner.go @@ -12,16 +12,17 @@ import ( func main() { fmt.Println("Opening database db.sqlite") - db := bootstrapDatabase() + orm := bootstrapDatabase() - if db == nil { + if orm == nil { log.Fatal("Failed to init database") return } + db, _ := orm.DB() defer db.Close() r := gin.Default() - apis.BindAPIs(r, db) + apis.BindAPIs(r, orm) r.Run() }