diff --git a/apis/base.go b/apis/base.go index bbf34da..a60e9e6 100644 --- a/apis/base.go +++ b/apis/base.go @@ -23,5 +23,6 @@ func BindAPIs(r *gin.Engine, cfg_db *gorm.DB) error { bindPlanAPIs(r) bindPollAPIs(r) bindUserAPIs(r) + bindExpensesAPIs(r) return nil } diff --git a/apis/expenses.go b/apis/expenses.go new file mode 100644 index 0000000..c38429f --- /dev/null +++ b/apis/expenses.go @@ -0,0 +1,231 @@ +package apis + +import ( + "net/http" + "planner/core" + "slices" + "strconv" + + "github.com/gin-gonic/gin" + "github.com/shopspring/decimal" +) + +type ExpenseDef struct { + Payer string `json:"payer"` + Type string `json:"type"` + Amount decimal.Decimal `json:"amount"` +} + +func createExpense(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + plan_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || plan_id < 0 { + c.Status(http.StatusBadRequest) + return + } + plan, err := u.GetPlan(db, uint(plan_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + members, err := plan.GetAllUsers(db) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + var exp ExpenseDef + if err := c.ShouldBind(&exp); err != nil { + c.Status(http.StatusBadRequest) + return + } + + var compare func(m core.Member) bool + + if exp.Type == "user" { + compare = func(m core.Member) bool { + return m.UserID == exp.Payer + } + } else if exp.Type == "non-user" { + compare = func(m core.Member) bool { + return m.Name == exp.Payer + } + } else { + c.String(http.StatusBadRequest, "Invalid member type") + return + } + idx := slices.IndexFunc(members, compare) + if idx != -1 { + expense, err := core.CreateExpense(db, plan, members[idx], exp.Amount) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + c.JSON(http.StatusOK, expense) + return + } + c.String(http.StatusBadRequest, "Unable to found member") +} + +func listExpenses(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + plan_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || plan_id < 0 { + c.Status(http.StatusBadRequest) + return + } + plan, err := u.GetPlan(db, uint(plan_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + expenses, err := core.ListExpenses(db, plan) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + c.JSON(http.StatusOK, expenses) +} + +func getExpense(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || expense_id < 0 { + c.Status(http.StatusBadRequest) + return + } + expense, err := core.GetExpense(db, uint(expense_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + _, err = u.GetPlan(db, expense.PlanID) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + c.JSON(http.StatusOK, expense) +} + +func deleteExpense(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || expense_id < 0 { + c.Status(http.StatusBadRequest) + return + } + expense, err := core.GetExpense(db, uint(expense_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + _, err = u.GetPlan(db, expense.PlanID) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + if err = expense.Delete(db); err != nil { + c.Status(http.StatusInternalServerError) + return + } + + c.Status(http.StatusOK) +} + +func getExpenseDebts(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || expense_id < 0 { + c.Status(http.StatusBadRequest) + return + } + expense, err := core.GetExpense(db, uint(expense_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + _, err = u.GetPlan(db, expense.PlanID) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + + debts, err := expense.GetDebt(db) + if err != nil { + c.Status(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, debts) +} + +func setExpenseDebts(c *gin.Context) { + u := extractUser(db, c) + if u == nil { + c.Status(http.StatusUnauthorized) + return + } + + expense_id, err := strconv.ParseInt(c.Param("id"), 10, 32) + if err != nil || expense_id < 0 { + c.Status(http.StatusBadRequest) + return + } + expense, err := core.GetExpense(db, uint(expense_id)) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + plan, err := u.GetPlan(db, expense.PlanID) + if err != nil { + c.Status(http.StatusBadRequest) + return + } + expense.Plan = plan + + var debtSpec []core.DebtSpec + if err := c.ShouldBind(&debtSpec); err != nil { + c.Status(http.StatusBadRequest) + return + } + + //expense.SetDebt(db, debtSpec) + + c.JSON(http.StatusOK, debtSpec) +} + +func bindExpensesAPIs(r *gin.Engine) { + r.POST("/plans/:id/expenses", createExpense) + r.GET("/plans/:id/expenses", listExpenses) + r.GET("/expenses/:id", getExpense) + r.DELETE("/expenses/:id", deleteExpense) + r.GET("/expenses/:id/debts", getExpenseDebts) + r.POST("/expenses/:id/debts", setExpenseDebts) +} diff --git a/apis/plans.go b/apis/plans.go index 2b40112..5817034 100644 --- a/apis/plans.go +++ b/apis/plans.go @@ -119,7 +119,7 @@ func addPlanMember(c *gin.Context) { return } - err = plan.AddMember(db, &Member{Name: new_member.Name, Type: "non-user"}) + err = plan.AddMember(db, &Member{Name: new_member.Name}) if err == nil { c.JSON(http.StatusOK, new_member) @@ -171,7 +171,7 @@ func joinPlan(c *gin.Context) { c.String(http.StatusConflict, "User already a member") return } - plan.AddMember(db, &Member{Type: "user", UserID: user.Username}) + plan.AddMember(db, &Member{UserID: user.Username}) c.Status(http.StatusOK) } diff --git a/core/expenses.go b/core/expenses.go new file mode 100644 index 0000000..31dbbb8 --- /dev/null +++ b/core/expenses.go @@ -0,0 +1,137 @@ +package core + +import ( + "encoding/json" + "errors" + + "github.com/shopspring/decimal" + "gorm.io/gorm" +) + +type DebtType int + +const ( + ProportionalDebt DebtType = iota + AbsoluteDebt = iota +) + +func (dt DebtType) String() string { + return [...]string{"proportional", "absolute"}[dt] +} + +func (dt DebtType) MarshalJSON() ([]byte, error) { + return json.Marshal(dt.String()) +} + +func (dt *DebtType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return nil + } + if s == "proportional" { + *dt = ProportionalDebt + } else if s == "absolute" { + *dt = AbsoluteDebt + } else { + return errors.New("Invalid value for debt type") + } + return nil +} + +type DebtSpec struct { + Member Member `json:"member"` + Amount decimal.Decimal `json:"amount"` + DebtType DebtType `json:"debt_type"` +} + +func CreateExpense(db *gorm.DB, plan Plan, payer Member, amount decimal.Decimal) (*Expense, error) { + expense := new(Expense) + expense.Plan = plan + expense.Payer = payer + expense.Amount = amount + err := db.Create(&expense).Error + if err != nil { + return nil, err + } + return expense, nil +} + +func (e *Expense) GetDebt(db *gorm.DB) ([]Debt, error) { + var debts []Debt + if err := db.Model(e).Association("Debts").Find(&debts); err != nil { + return []Debt{}, err + } + return debts, nil +} + +func (e *Expense) SetDebt(db *gorm.DB, debtSpec []DebtSpec) error { + abs_paid := decimal.Decimal{} + debts := make([]Debt, 0) + var prop_payers int64 = 0 + for _, debt := range debtSpec { + if debt.DebtType == ProportionalDebt { + prop_payers += 1 + } else { + debts = append(debts, Debt{ + Amount: debt.Amount, + Paid: decimal.Decimal{}, + ExpenseID: e.ID, + DebtorID: debt.Member.ID, + }) + abs_paid.Add(debt.Amount) + } + } + if abs_paid.Cmp(e.Amount) > 0 { + return errors.New("Absolute pay amount is larger than debt") + } + + prop_debt := e.Amount.Sub(abs_paid).DivRound(decimal.NewFromInt(prop_payers), 2) + return db.Transaction(func(tx *gorm.DB) error { + for _, debt := range debts { + if err := db.Create(&debt).Error; err != nil { + return err + } + } + + for _, debt := range debtSpec { + if debt.DebtType != ProportionalDebt { + continue + } + if err := db.Create(&Debt{ + Amount: prop_debt, + Paid: decimal.Decimal{}, + ExpenseID: e.ID, + DebtorID: debt.Member.ID, + }).Error; err != nil { + return err + } + } + return nil + }) +} + +func (e *Expense) Delete(db *gorm.DB) error { + return db.Transaction(func(tx *gorm.DB) error { + if err := db.Model(e).Association("Debts").Delete(&[]Debt{}); err != nil { + return err + } + return db.Delete(e, e).Error + }) +} + +func ListExpenses(db *gorm.DB, plan Plan) ([]Expense, error) { + var expenses []Expense + if err := db.Where("plan_id = ?", plan.ID).Find(&expenses).Error; err != nil { + return []Expense{}, err + } + return expenses, nil +} + +func GetExpense(db *gorm.DB, expense_id uint) (*Expense, error) { + expense := new(Expense) + expense.ID = expense_id + if err := db.Take(expense).Error; err != nil { + return nil, err + } + return expense, nil +} diff --git a/core/models.go b/core/models.go index 53268c0..f4bceb8 100644 --- a/core/models.go +++ b/core/models.go @@ -1,5 +1,7 @@ package core +import "github.com/shopspring/decimal" + type User struct { Username string `gorm:"primaryKey" json:"username"` Password string `json:"password"` @@ -12,9 +14,8 @@ type Member struct { ID uint `gorm:"primaryKey;autoIncrement:true" json:"-"` PlanID uint `json:"-"` Plan Plan `json:"-"` - Type string `gorm:"check:type in ('user','non-user')" json:"type"` - Name string `gorm:"check:type=='member' OR name IS NOT NULL" json:"name"` - UserID string `json:"username"` + Name string `gorm:"check:user_id IS NOT NULL OR name IS NOT NULL" json:"name,omitempty"` + UserID string `json:"username,omitempty"` User User `gorm:"foreignKey:UserID" json:"-"` } @@ -30,6 +31,25 @@ type Plan struct { Polls []Poll `gorm:"foreignKey:PlanID;references:ID" json:"-"` } +type Expense struct { + ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` + PlanID uint `json:"-"` + Plan Plan `json:"-"` + PayerID uint `json:"-"` + Payer Member `gorm:"foreignKey:PayerID" json:"-"` + Amount decimal.Decimal `json:"amount"` + Debts []Debt `gorm:"foreignKey:ExpenseID" json:"debts,omitempty"` +} + +type Debt struct { + ExpenseID uint `gorm:"primaryKey" json:"-"` + Expense Expense `gorm:"foreignKey:ExpenseID"` + DebtorID uint `gorm:"primaryKey" json:"-"` + Debtor Member `gorm:"foreignKey:DebtorID"` + Amount decimal.Decimal + Paid decimal.Decimal +} + // CREATE TABLE polls(id INTEGER PRIMARY KEY AUTOINCREMENT, plan INTEGER, name STRING, options JSON, FOREIGN KEY plan REFERENCES plans(id)) type Poll struct { ID uint `gorm:"primaryKey;autoIncrement:true" json:"id"` diff --git a/core/plans.go b/core/plans.go index 32ad597..28403c9 100644 --- a/core/plans.go +++ b/core/plans.go @@ -21,7 +21,6 @@ func CreatePlan(db *gorm.DB, user *User, name string) (*Plan, error) { Members: []Member{ { UserID: user.Username, - Type: "user", }, }, JoinCode: base64.URLEncoding.EncodeToString(join_code), @@ -95,10 +94,7 @@ func (p *Plan) AddMember(orm *gorm.DB, new_member *Member) error { return errors.New("Member is nil") } new_member.PlanID = p.ID - if new_member.Type == "non-user" { - if new_member.Name == "" { - return errors.New("name required for non user") - } + if new_member.Name != "" { found, err := p.HasNonUser(orm, new_member.Name) if err != nil { return nil @@ -107,7 +103,7 @@ func (p *Plan) AddMember(orm *gorm.DB, new_member *Member) error { return errors.New("Non user name taken") } return orm.Create(&new_member).Error - } else if new_member.Type == "user" { + } else if new_member.UserID != "" { user, err := GetUser(orm, new_member.UserID) if err != nil { return err @@ -123,6 +119,6 @@ func (p *Plan) AddMember(orm *gorm.DB, new_member *Member) error { return orm.Create(&new_member).Error } else { - return errors.New("Invalid type for user") + return errors.New("Member object requires one of Name or UserID to be filled") } } diff --git a/db.go b/db.go index 1f49809..b0e63a8 100644 --- a/db.go +++ b/db.go @@ -19,7 +19,7 @@ func bootstrapDatabase() *gorm.DB { return nil } - db.AutoMigrate(&User{}, &Member{}, &Plan{}, &Poll{}, &Vote{}) + db.AutoMigrate(&User{}, &Member{}, &Plan{}, &Expense{}, &Debt{}, &Poll{}, &Vote{}) //var tables = [...]struct { // key string diff --git a/go.mod b/go.mod index 581ef33..9bef346 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.10.0 // indirect diff --git a/go.sum b/go.sum index 45e3183..96934a0 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNH github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=