Commit 145af892 authored by Sybren A. Stüvel's avatar Sybren A. Stüvel

Added comparisons to Amount struct

parent b3367d59
......@@ -29,8 +29,10 @@ import (
"strings"
)
// Errors returned by the Amount functions.
var (
errInvalidJSONString = errors.New("Invalid JSON, field must be a string \"xxx.cc\"")
ErrInvalidJSONString = errors.New("Invalid JSON, field must be a string \"xxx.cc\"")
ErrMixedCurrency = errors.New("amounts have different currencies")
)
// Amount describes a monetary amount.
......@@ -44,8 +46,83 @@ type Amount struct {
// Cents represents an integer number of cents (1/100th unit of currency).
type Cents int
// NewAmount creates a new amount of integer cents.
func NewAmount(currency string, cents int) Amount {
return Amount{
Currency: currency,
Value: Cents(cents),
}
}
// Add returns the sum. Currencies must match.
func (amt Amount) Add(other Amount) (sum Amount, err error) {
if amt.Currency != other.Currency {
err = ErrMixedCurrency
return
}
sum = Amount{
Currency: amt.Currency,
Value: amt.Value.Add(other.Value),
}
return
}
// Sub returns the difference. Currencies must match.
func (amt Amount) Sub(other Amount) (sum Amount, err error) {
if amt.Currency != other.Currency {
err = ErrMixedCurrency
return
}
sum = Amount{
Currency: amt.Currency,
Value: amt.Value.Sub(other.Value),
}
return
}
func (amt Amount) cmp(other Amount) (int, error) {
if amt.Currency != other.Currency {
return 0, ErrMixedCurrency
}
diff := amt.Value.Sub(other.Value).Cents()
switch {
case diff < 0:
return -1, nil
case diff > 0:
return 1, nil
default:
return 0, nil
}
}
// LessThan returns true iff amt < other
func (amt Amount) LessThan(other Amount) (bool, error) {
cmp, err := amt.cmp(other)
if err != nil {
return false, err
}
return cmp < 0, nil
}
// GreaterThan returns true iff amt > other
func (amt Amount) GreaterThan(other Amount) (bool, error) {
cmp, err := amt.cmp(other)
if err != nil {
return false, err
}
return cmp > 0, nil
}
// Equals returns true if both amounts are equal in currency and value.
func (amt Amount) Equals(other Amount) bool {
return amt.Currency == other.Currency && amt.Value == other.Value
}
func (amt Amount) String() string {
return fmt.Sprintf("%s %s", amt.Value.String(), amt.Currency)
return fmt.Sprintf("%s %s", amt.Currency, amt.Value.String())
}
// Cents returns the number of cents as int.
......@@ -53,6 +130,16 @@ func (c Cents) Cents() int {
return int(c)
}
// Add returns the sum
func (c Cents) Add(other Cents) Cents {
return Cents(int(c) + int(other))
}
// Sub returns the difference
func (c Cents) Sub(other Cents) Cents {
return Cents(int(c) - int(other))
}
func (c Cents) String() string {
cents := int(c)
return fmt.Sprintf("%d.%02d", cents/100, cents%100)
......@@ -62,13 +149,13 @@ func (c Cents) String() string {
func (c *Cents) UnmarshalJSON(input []byte) error {
inLen := len(input)
if inLen < 2 || input[0] != '"' || input[inLen-1] != '"' {
return errInvalidJSONString
return ErrInvalidJSONString
}
asString := string(input[1 : inLen-1])
parts := strings.Split(asString, ".")
if len(parts) != 2 {
return errInvalidJSONString
return ErrInvalidJSONString
}
wholeUnits, err := strconv.Atoi(parts[0])
......@@ -80,7 +167,7 @@ func (c *Cents) UnmarshalJSON(input []byte) error {
return err
}
if cents >= 100 {
return errInvalidJSONString
return ErrInvalidJSONString
}
*c = Cents(wholeUnits*100 + cents)
......
......@@ -54,3 +54,41 @@ func TestCents_MarshalJSON_Happy(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, `{"value":"123.45","currency":"EUR"}`, string(asJSON))
}
func TestAmount_Comparisons(t *testing.T) {
a := NewAmount("EUR", 47)
b := NewAmount("EUR", 327)
sum, err := a.Add(b)
assert.Nil(t, err)
assert.Equal(t, a.Currency, sum.Currency)
assert.Equal(t, 47+327, sum.Value.Cents())
assert.Equal(t, 47, a.Value.Cents())
assert.Equal(t, 327, b.Value.Cents())
diff, err := a.Sub(b)
assert.Nil(t, err)
assert.Equal(t, a.Currency, diff.Currency)
assert.Equal(t, 47-327, diff.Value.Cents())
assert.Equal(t, 47, a.Value.Cents())
assert.Equal(t, 327, b.Value.Cents())
isLess, err := a.LessThan(b)
assert.Nil(t, err)
assert.True(t, isLess)
isLess, err = b.LessThan(a)
assert.Nil(t, err)
assert.False(t, isLess)
isGreater, err := a.GreaterThan(b)
assert.Nil(t, err)
assert.False(t, isGreater)
isGreater, err = b.GreaterThan(a)
assert.Nil(t, err)
assert.True(t, isGreater)
assert.True(t, a.Equals(a))
assert.True(t, a.Equals(NewAmount("EUR", 47)))
assert.False(t, a.Equals(b))
assert.False(t, a.Equals(NewAmount("HRK", a.Value.Cents())))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment