package core import ( "database/sql" "encoding/json" "errors" "fmt" "github.com/shopspring/decimal" ) type DebtType int const ( ProportionalDebt DebtType = iota AbsoluteDebt = iota ) func (dt DebtType) String() string { return [...]string{"proportional", "absolute"}[dt] } func (dt DebtType) MarshalJSON() ([]byte, error) { return json.Marshal(dt.String()) } func (dt *DebtType) UnmarshalJSON(data []byte) error { var s string if err := json.Unmarshal(data, &s); err != nil { return nil } if s == "proportional" { *dt = ProportionalDebt } else if s == "absolute" { *dt = AbsoluteDebt } else { return errors.New("Invalid value for debt type") } return nil } type DebtSpec struct { Member Member `json:"member"` Amount decimal.Decimal `json:"amount"` DebtType DebtType `json:"debt_type"` } func ExpenseCreate(db *sql.DB, plan Plan, payer Member, amount decimal.Decimal) (*Expense, error) { expense := new(Expense) expense.PlanId = plan.Id expense.PayerId = payer.Id expense.Amount = amount row := db.QueryRow( "INSERT INTO expenses(plan_id,payer_id,amount) VALUES (?,?,?) RETURNING id", expense.PlanId, expense.PayerId, expense.Amount, ) if err := row.Scan(&expense.Id); err != nil { return nil, err } return expense, nil } func (e *Expense) GetDebt(db *sql.DB) ([]Debt, error) { var debts []Debt rows, err := db.Query("SELECT expense_id,debtor_id,amount,paid FROM debts WHERE expense_id=?") if err != nil { return debts, err } defer rows.Close() for rows.Next() { var d Debt err = rows.Scan(&d.ExpenseId, &d.DebtorId, &d.Amount, &d.Paid) if err != nil { return []Debt{}, err } debts = append(debts, d) } return debts, nil } func (e *Expense) SetDebt(db *sql.DB, debtSpec []DebtSpec) error { abs_paid := decimal.Decimal{} debts := make([]Debt, 0) var prop_payers int64 = 0 for _, debt := range debtSpec { if debt.DebtType == ProportionalDebt { prop_payers += 1 } else { debts = append(debts, Debt{ Amount: debt.Amount, Paid: decimal.Decimal{}, ExpenseId: e.Id, DebtorId: debt.Member.Id, }) abs_paid.Add(debt.Amount) } } if abs_paid.Cmp(e.Amount) > 0 { return errors.New("Absolute pay amount is larger than debt") } prop_debt := e.Amount.Sub(abs_paid).DivRound(decimal.NewFromInt(prop_payers), 2) tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() for _, debt := range debts { if err := debt.Create(tx); err != nil { return err } } for _, debt := range debtSpec { if debt.DebtType != ProportionalDebt { continue } debtObj := Debt{ Amount: prop_debt, Paid: decimal.Decimal{}, ExpenseId: e.Id, DebtorId: debt.Member.Id, } if err := debtObj.Create(tx); err != nil { return err } } return tx.Commit() } func (e *Expense) Delete(db *sql.DB) error { tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() _, err = tx.Exec("DELETE FROM debts WHERE expense_id=?", e.Id) if err != nil { return fmt.Errorf("Unable to delete debts for expense %d: %w", e.Id, err) } _, err = tx.Exec("DELETE FROM expenses WHERE id=?", e.Id) if err != nil { return fmt.Errorf("Unable to delete expense %d: %w", e.Id, err) } err = tx.Commit() if err != nil { return fmt.Errorf( "Unable to commit transaction when deleting expense %d: %w", e.Id, err, ) } return nil } func ExpensesList(db *sql.DB, plan Plan) ([]Expense, error) { var expenses []Expense rows, err := db.Query( "SELECT id,plan_id,payer_id,amount FROM expenses WHERE plan_id=$1", plan.Id, ) if err != nil { return expenses, fmt.Errorf( "Unable to query for expenses in plan %d: %w", plan.Id, err, ) } defer rows.Close() for rows.Next() { e := Expense{} err = rows.Scan(&e.Id, e.PlanId, e.PayerId, e.Amount) if err != nil { return expenses, fmt.Errorf( "Unable to scan expense during list: %w", err, ) } expenses = append(expenses, e) } return expenses, nil } func ExpensesGet(db *sql.DB, expense_id uint) (*Expense, error) { e := Expense{} row := db.QueryRow( "SELECT id,plan_id,payer_id,amount FROM expenses WHERE id=?", expense_id, ) if err := row.Scan(&e.Id, &e.PlanId, &e.PayerId, &e.Amount); err != nil { return nil, fmt.Errorf("Unable to get expense %d: %w", expense_id, err) } return &e, nil }