package core

import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/shopspring/decimal"
	"gorm.io/gorm"
)

// DebtType describes how a debtor's share of an expense is computed by
// (*Expense).SetDebt.
type DebtType int

const (
	// ProportionalDebt debtors split evenly whatever remains of the expense
	// after all absolute debts have been subtracted.
	ProportionalDebt DebtType = iota
	// AbsoluteDebt debtors owe exactly the Amount given in their DebtSpec.
	// Written without "= iota" so the constant keeps the DebtType type
	// instead of becoming an untyped int.
	AbsoluteDebt
)

// debtTypeNames holds the string/JSON name of each DebtType, indexed by value.
var debtTypeNames = [...]string{
	ProportionalDebt: "proportional",
	AbsoluteDebt:     "absolute",
}

// valid reports whether dt is one of the declared DebtType constants.
func (dt DebtType) valid() bool {
	return dt >= 0 && int(dt) < len(debtTypeNames)
}

// String returns "proportional" or "absolute". Values outside the declared
// range are rendered as "DebtType(N)" instead of panicking.
func (dt DebtType) String() string {
	if !dt.valid() {
		return fmt.Sprintf("DebtType(%d)", int(dt))
	}
	return debtTypeNames[dt]
}

// MarshalJSON encodes dt as its string name. Unknown values are rejected
// with an error, so every encoded value can be decoded by UnmarshalJSON.
func (dt DebtType) MarshalJSON() ([]byte, error) {
	if !dt.valid() {
		return nil, fmt.Errorf("invalid debt type %d", int(dt))
	}
	return json.Marshal(dt.String())
}

// UnmarshalJSON decodes the JSON string "proportional" or "absolute" into dt.
// A JSON decoding error is returned as-is; any other input, including null
// or an unknown name, yields an error. dt is left unchanged on error.
func (dt *DebtType) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	switch s {
	case "proportional":
		*dt = ProportionalDebt
	case "absolute":
		*dt = AbsoluteDebt
	default:
		return errors.New("Invalid value for debt type")
	}
	return nil
}

// DebtSpec describes one debtor's share of an expense as consumed by
// (*Expense).SetDebt. Amount is only read for non-proportional entries.
type DebtSpec struct {
	Member   Member          `json:"member"`
	Amount   decimal.Decimal `json:"amount"`
	DebtType DebtType        `json:"debt_type"`
}

// CreateExpense inserts a new Expense for plan, paid by payer, with the
// given amount, and returns it after gorm's Create has run. On error the
// returned Expense is nil.
func CreateExpense(db *gorm.DB, plan Plan, payer Member, amount decimal.Decimal) (*Expense, error) {
	expense := new(Expense)
	expense.Plan = plan
	expense.Payer = payer
	expense.Amount = amount
	// expense is already a pointer; pass it directly rather than a **Expense.
	if err := db.Create(expense).Error; err != nil {
		return nil, err
	}
	return expense, nil
}

// GetDebt loads the Debt rows linked to e through its "Debts" association.
// On error it returns an empty, non-nil slice together with the error.
func (e *Expense) GetDebt(db *gorm.DB) ([]Debt, error) {
	var debts []Debt
	if err := db.Model(e).Association("Debts").Find(&debts); err != nil {
		return []Debt{}, err
	}
	return debts, nil
}

// SetDebt creates Debt rows for e from debtSpec.
//
// Every entry whose DebtType is not ProportionalDebt is charged exactly its
// Amount. The remainder (e.Amount minus the sum of those amounts) is split
// evenly among the ProportionalDebt entries and rounded to 2 decimal places
// with DebtRound; their Amount field is ignored. If the fixed amounts exceed
// e.Amount, an error is returned and nothing is written. All inserts run in
// one transaction.
//
// NOTE(review): because of rounding, the created debts may not sum exactly
// to e.Amount. Existing debts of e are not removed first; confirm whether
// callers expect SetDebt to replace them.
func (e *Expense) SetDebt(db *gorm.DB, debtSpec []DebtSpec) error {
	absPaid := decimal.Zero
	absolute := make([]Debt, 0, len(debtSpec))
	var propPayers int64

	for _, spec := range debtSpec {
		if spec.DebtType == ProportionalDebt {
			propPayers++
			continue
		}
		absolute = append(absolute, Debt{
			Amount:    spec.Amount,
			Paid:      decimal.Zero,
			ExpenseID: e.ID,
			DebtorID:  spec.Member.ID,
		})
		// decimal.Decimal is immutable: Add returns the sum and does not
		// modify its receiver, so the result must be reassigned.
		absPaid = absPaid.Add(spec.Amount)
	}

	if absPaid.Cmp(e.Amount) > 0 {
		return errors.New("Absolute pay amount is larger than debt")
	}

	// Only compute a share when someone pays proportionally; dividing by a
	// zero count would panic.
	var propDebt decimal.Decimal
	if propPayers > 0 {
		propDebt = e.Amount.Sub(absPaid).DivRound(decimal.NewFromInt(propPayers), 2)
	}

	return db.Transaction(func(tx *gorm.DB) error {
		// Every write goes through tx so it is committed or rolled back
		// together with the others.
		for i := range absolute {
			if err := tx.Create(&absolute[i]).Error; err != nil {
				return err
			}
		}
		for _, spec := range debtSpec {
			if spec.DebtType != ProportionalDebt {
				continue
			}
			if err := tx.Create(&Debt{
				Amount:    propDebt,
				Paid:      decimal.Zero,
				ExpenseID: e.ID,
				DebtorID:  spec.Member.ID,
			}).Error; err != nil {
				return err
			}
		}
		return nil
	})
}

// Delete removes every Debt whose ExpenseID equals e.ID and then e itself,
// in one transaction. e must have a non-zero ID; otherwise
// gorm.ErrMissingWhereClause is returned and nothing is touched.
func (e *Expense) Delete(db *gorm.DB) error {
	// Without an ID the debt delete below would target rows with
	// expense_id = 0 rather than this expense's debts.
	if e.ID == 0 {
		return gorm.ErrMissingWhereClause
	}
	return db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Where("expense_id = ?", e.ID).Delete(&Debt{}).Error; err != nil {
			return err
		}
		return tx.Delete(e).Error
	})
}

// ListExpenses returns all expenses whose plan_id equals plan.ID. On error
// it returns an empty, non-nil slice together with the error.
func ListExpenses(db *gorm.DB, plan Plan) ([]Expense, error) {
	var expenses []Expense
	if err := db.Where("plan_id = ?", plan.ID).Find(&expenses).Error; err != nil {
		return []Expense{}, err
	}
	return expenses, nil
}

// GetExpense loads the Expense whose ID is expenseID. An ID of 0 yields
// gorm.ErrRecordNotFound; any error from Take is returned unchanged.
func GetExpense(db *gorm.DB, expenseID uint) (*Expense, error) {
	// With a zero primary key, Take would have no ID condition and return an
	// arbitrary row, so reject it explicitly.
	if expenseID == 0 {
		return nil, gorm.ErrRecordNotFound
	}
	expense := new(Expense)
	expense.ID = expenseID
	if err := db.Take(expense).Error; err != nil {
		return nil, err
	}
	return expense, nil
}