diff --git a/go.mod b/go.mod index bdcd2cd..86d9488 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( ) require ( + github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/pflag v1.0.10 // indirect golang.org/x/mod v0.27.0 // indirect golang.org/x/tools v0.36.0 // indirect diff --git a/go.sum b/go.sum index d1918fa..101ec92 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/internal/errors.go b/internal/errors.go new file mode 100644 index 0000000..8e6c7a8 --- /dev/null +++ b/internal/errors.go @@ -0,0 +1,5 @@ +package internal + +import "fmt" + +var ErrInsufficientBoughtVolume = fmt.Errorf("insufficient bought volume") diff --git a/internal/filler.go b/internal/filler.go new file mode 100644 index 0000000..2d03e9d --- /dev/null +++ b/internal/filler.go @@ -0,0 +1,96 @@ +package internal + +import ( + "container/list" + + "github.com/shopspring/decimal" +) + +type Filler struct { + Record + + filled decimal.Decimal +} + +func NewFiller(r Record) *Filler { + return &Filler{ + Record: r, + } +} + +// Fill accrues some quantity. Returns how mutch was accrued in the 1st return value and whether +// it was filled or not on the 2nd return value. +func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) { + unfilled := f.Record.Quantity().Sub(f.filled) + delta := decimal.Min(unfilled, quantity) + f.filled = f.filled.Add(delta) + return delta, f.IsFilled() +} + +// IsFilled returns true if the fill is equal to the record quantity. +func (f *Filler) IsFilled() bool { + return f.filled.Equal(f.Quantity()) +} + +type FillerQueue struct { + l *list.List +} + +// Push inserts the Filler at the back of the queue. +func (fq *FillerQueue) Push(f *Filler) { + if f == nil { + return + } + + if fq == nil { + // This would cause a panic anyway so, we panic with a more meaningful message + panic("Push to nil FillerQueue") + } + + if fq.l == nil { + fq.l = list.New() + } + + fq.l.PushBack(f) +} + +// Pop removes and returns the first Filler of the queue in the 1st return value. If the list is +// empty returns false on the 2nd return value, true otherwise. +func (fq *FillerQueue) Pop() (*Filler, bool) { + el := fq.frontElement() + if el == nil { + return nil, false + } + + val := fq.l.Remove(el) + + return val.(*Filler), true +} + +// Peek returns the front Filler of the queue in the 1st return value. If the list is empty returns +// false on the 2nd return value, true otherwise. +func (fq *FillerQueue) Peek() (*Filler, bool) { + el := fq.frontElement() + if el == nil { + return nil, false + } + + return el.Value.(*Filler), true +} + +func (fq *FillerQueue) frontElement() *list.Element { + if fq == nil || fq.l == nil { + return nil + } + + return fq.l.Front() +} + +// Len returns how many elements are currently on the queue +func (fq *FillerQueue) Len() int { + if fq == nil || fq.l == nil { + return 0 + } + + return fq.l.Len() +} diff --git a/internal/filler_test.go b/internal/filler_test.go new file mode 100644 index 0000000..3609232 --- /dev/null +++ b/internal/filler_test.go @@ -0,0 +1,187 @@ +package internal + +import ( + "testing" + + "github.com/shopspring/decimal" +) + +func TestFillerQueue(t *testing.T) { + var recCount int + newRecord := func() Record { + recCount++ + return testRecord{ + id: recCount, + } + } + + var rq FillerQueue + + if rq.Len() != 0 { + t.Fatalf("zero value should have zero lenght") + } + + _, ok := rq.Pop() + if ok { + t.Fatalf("Pop() should return (_,false) on a zero value") + } + + _, ok = rq.Peek() + if ok { + t.Fatalf("Peek() should return (_,false) on a zero value") + } + + rq.Push(nil) + if rq.Len() != 0 { + t.Fatalf("pushing nil should be a no-op") + } + + rq.Push(NewFiller(newRecord())) + if rq.Len() != 1 { + t.Fatalf("pushing 1st record should result in lenght of 1") + } + + rq.Push(NewFiller(newRecord())) + if rq.Len() != 2 { + t.Fatalf("pushing 2nd record should result in lenght of 2") + } + + peekFiller, ok := rq.Peek() + if !ok { + t.Fatalf("Peek() should return (_,true) when the list is not empty") + } + + if rec, ok := peekFiller.Record.(testRecord); ok { + if rec.id != 1 { + t.Fatalf("Peek() should return the 1st record pushed but returned %d", rec.id) + } + } else { + t.Fatalf("Peek() should return the original record type") + } + + if rq.Len() != 2 { + t.Fatalf("Peek() should not affect the list length") + } + + popFiller, ok := rq.Pop() + if !ok { + t.Fatalf("Pop() should return (_,true) when the list is not empty") + } + + if rec, ok := popFiller.Record.(testRecord); ok { + if rec.id != 1 { + t.Fatalf("Pop() should return the first record pushed but returned %d", rec.id) + } + } else { + t.Fatalf("Pop() should return the original record") + } + + if rq.Len() != 1 { + t.Fatalf("Pop() should remove an element from the list") + } +} + +func TestFillerQueueNilReceiver(t *testing.T) { + var rq *FillerQueue + + if rq.Len() > 0 { + t.Fatalf("nil receiver should have zero lenght") + } + + _, ok := rq.Peek() + if ok { + t.Fatalf("Peek() on a nil receiver should return (_,false)") + } + + _, ok = rq.Pop() + if ok { + t.Fatalf("Pop() on a nil receiver should return (_,false)") + } + + rq.Push(nil) + if rq.Len() != 0 { + t.Fatalf("Push(nil) on a nil receiver should be a no-op") + } + + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected a panic but got nothing") + } + + expMsg := "Push to nil FillerQueue" + if msg, ok := r.(string); !ok || msg != expMsg { + t.Fatalf(`want panic message %q but got "%v"`, expMsg, r) + } + }() + rq.Push(NewFiller(nil)) +} + +type testRecord struct { + Record + + id int + quantity decimal.Decimal +} + +func (tr testRecord) Quantity() decimal.Decimal { + return tr.quantity +} + +func TestFiller_Fill(t *testing.T) { + tests := []struct { + name string + r Record + quantity decimal.Decimal + want decimal.Decimal + wantBool bool + }{ + { + name: "fills 0 of zero quantity", + r: &testRecord{quantity: decimal.NewFromFloat(0.0)}, + quantity: decimal.Decimal{}, + want: decimal.Decimal{}, + wantBool: true, + }, + { + name: "fills 0 of positive quantity", + r: &testRecord{quantity: decimal.NewFromFloat(100.0)}, + quantity: decimal.Decimal{}, + want: decimal.Decimal{}, + wantBool: false, + }, + { + name: "fills 10 out of 100 and no previous fills", + r: &testRecord{quantity: decimal.NewFromFloat(100.0)}, + quantity: decimal.NewFromFloat(10), + want: decimal.NewFromFloat(10), + wantBool: false, + }, + { + name: "fills 10 out of 10 and no previous fills", + r: &testRecord{quantity: decimal.NewFromFloat(10.0)}, + quantity: decimal.NewFromFloat(10), + want: decimal.NewFromFloat(10), + wantBool: true, + }, + { + name: "filling 100 fills 10 out of 10 and no previous fills", + r: &testRecord{quantity: decimal.NewFromFloat(10.0)}, + quantity: decimal.NewFromFloat(100), + want: decimal.NewFromFloat(10), + wantBool: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := NewFiller(tt.r) + got, gotBool := f.Fill(tt.quantity) + if !tt.want.Equal(got) { + t.Errorf("want 1st return value to be %v but got %v", tt.want, got) + } + if tt.wantBool != gotBool { + t.Errorf("want 2nd return value to be %v but got %v", tt.wantBool, gotBool) + } + }) + } +} diff --git a/internal/mocks/mocks_gen.go b/internal/mocks/mocks_gen.go index 463d696..104d0bf 100644 --- a/internal/mocks/mocks_gen.go +++ b/internal/mocks/mocks_gen.go @@ -11,11 +11,11 @@ package mocks import ( context "context" - big "math/big" reflect "reflect" time "time" internal "github.com/nmoniz/any2anexoj/internal" + decimal "github.com/shopspring/decimal" gomock "go.uber.org/mock/gomock" ) @@ -107,10 +107,10 @@ func (m *MockRecord) EXPECT() *MockRecordMockRecorder { } // Fees mocks base method. -func (m *MockRecord) Fees() *big.Float { +func (m *MockRecord) Fees() decimal.Decimal { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Fees") - ret0, _ := ret[0].(*big.Float) + ret0, _ := ret[0].(decimal.Decimal) return ret0 } @@ -127,28 +127,28 @@ type MockRecordFeesCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockRecordFeesCall) Return(arg0 *big.Float) *MockRecordFeesCall { +func (c *MockRecordFeesCall) Return(arg0 decimal.Decimal) *MockRecordFeesCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockRecordFeesCall) Do(f func() *big.Float) *MockRecordFeesCall { +func (c *MockRecordFeesCall) Do(f func() decimal.Decimal) *MockRecordFeesCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRecordFeesCall) DoAndReturn(f func() *big.Float) *MockRecordFeesCall { +func (c *MockRecordFeesCall) DoAndReturn(f func() decimal.Decimal) *MockRecordFeesCall { c.Call = c.Call.DoAndReturn(f) return c } // Price mocks base method. -func (m *MockRecord) Price() *big.Float { +func (m *MockRecord) Price() decimal.Decimal { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Price") - ret0, _ := ret[0].(*big.Float) + ret0, _ := ret[0].(decimal.Decimal) return ret0 } @@ -165,28 +165,28 @@ type MockRecordPriceCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockRecordPriceCall) Return(arg0 *big.Float) *MockRecordPriceCall { +func (c *MockRecordPriceCall) Return(arg0 decimal.Decimal) *MockRecordPriceCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockRecordPriceCall) Do(f func() *big.Float) *MockRecordPriceCall { +func (c *MockRecordPriceCall) Do(f func() decimal.Decimal) *MockRecordPriceCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRecordPriceCall) DoAndReturn(f func() *big.Float) *MockRecordPriceCall { +func (c *MockRecordPriceCall) DoAndReturn(f func() decimal.Decimal) *MockRecordPriceCall { c.Call = c.Call.DoAndReturn(f) return c } // Quantity mocks base method. -func (m *MockRecord) Quantity() *big.Float { +func (m *MockRecord) Quantity() decimal.Decimal { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Quantity") - ret0, _ := ret[0].(*big.Float) + ret0, _ := ret[0].(decimal.Decimal) return ret0 } @@ -203,19 +203,19 @@ type MockRecordQuantityCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockRecordQuantityCall) Return(arg0 *big.Float) *MockRecordQuantityCall { +func (c *MockRecordQuantityCall) Return(arg0 decimal.Decimal) *MockRecordQuantityCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockRecordQuantityCall) Do(f func() *big.Float) *MockRecordQuantityCall { +func (c *MockRecordQuantityCall) Do(f func() decimal.Decimal) *MockRecordQuantityCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRecordQuantityCall) DoAndReturn(f func() *big.Float) *MockRecordQuantityCall { +func (c *MockRecordQuantityCall) DoAndReturn(f func() decimal.Decimal) *MockRecordQuantityCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -297,10 +297,10 @@ func (c *MockRecordSymbolCall) DoAndReturn(f func() string) *MockRecordSymbolCal } // Taxes mocks base method. -func (m *MockRecord) Taxes() *big.Float { +func (m *MockRecord) Taxes() decimal.Decimal { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Taxes") - ret0, _ := ret[0].(*big.Float) + ret0, _ := ret[0].(decimal.Decimal) return ret0 } @@ -317,19 +317,19 @@ type MockRecordTaxesCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockRecordTaxesCall) Return(arg0 *big.Float) *MockRecordTaxesCall { +func (c *MockRecordTaxesCall) Return(arg0 decimal.Decimal) *MockRecordTaxesCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockRecordTaxesCall) Do(f func() *big.Float) *MockRecordTaxesCall { +func (c *MockRecordTaxesCall) Do(f func() decimal.Decimal) *MockRecordTaxesCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRecordTaxesCall) DoAndReturn(f func() *big.Float) *MockRecordTaxesCall { +func (c *MockRecordTaxesCall) DoAndReturn(f func() decimal.Decimal) *MockRecordTaxesCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/internal/record.go b/internal/record.go index 03c2975..1fe43be 100644 --- a/internal/record.go +++ b/internal/record.go @@ -1,80 +1,17 @@ package internal import ( - "container/list" - "math/big" "time" + + "github.com/shopspring/decimal" ) type Record interface { Symbol() string Side() Side - Price() *big.Float - Quantity() *big.Float + Price() decimal.Decimal + Quantity() decimal.Decimal Timestamp() time.Time - Fees() *big.Float - Taxes() *big.Float -} - -type RecordQueue struct { - l *list.List -} - -// Push inserts the Record at the back of the queue. If pushing a nil Record then it's a no-op. -func (rq *RecordQueue) Push(r Record) { - if r == nil { - return - } - - if rq == nil { - // This would cause a panic anyway so, we panic with a more meaningful message - panic("Push to nil RecordQueue") - } - - if rq.l == nil { - rq.l = list.New() - } - - rq.l.PushBack(r) -} - -// Pop removes and returns the first Record of the queue in the 1st return value. If the list is -// empty returns false on the 2nd return value, true otherwise. -func (rq *RecordQueue) Pop() (Record, bool) { - el := rq.frontElement() - if el == nil { - return nil, false - } - - val := rq.l.Remove(el) - - return val.(Record), true -} - -// Peek returns the front Record of the queue in the 1st return value. If the list is empty returns -// false on the 2nd return value, true otherwise. -func (rq *RecordQueue) Peek() (Record, bool) { - el := rq.frontElement() - if el == nil { - return nil, false - } - - return el.Value.(Record), true -} - -func (rq *RecordQueue) frontElement() *list.Element { - if rq == nil || rq.l == nil { - return nil - } - - return rq.l.Front() -} - -// Len returns how many elements are currently on the queue -func (rq *RecordQueue) Len() int { - if rq == nil || rq.l == nil { - return 0 - } - - return rq.l.Len() + Fees() decimal.Decimal + Taxes() decimal.Decimal } diff --git a/internal/record_test.go b/internal/record_test.go deleted file mode 100644 index e7b237f..0000000 --- a/internal/record_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package internal - -import ( - "testing" -) - -func TestRecordQueue(t *testing.T) { - var recCount int - newRecord := func() Record { - recCount++ - return testRecord{ - id: recCount, - } - } - - var rq RecordQueue - - if rq.Len() != 0 { - t.Fatalf("zero value should have zero lenght") - } - - _, ok := rq.Pop() - if ok { - t.Fatalf("Pop() should return (_,false) on a zero value") - } - - _, ok = rq.Peek() - if ok { - t.Fatalf("Peek() should return (_,false) on a zero value") - } - - rq.Push(nil) - if rq.Len() != 0 { - t.Fatalf("pushing nil should be a no-op") - } - - rq.Push(newRecord()) - if rq.Len() != 1 { - t.Fatalf("pushing 1st record should result in lenght of 1") - } - - rq.Push(newRecord()) - if rq.Len() != 2 { - t.Fatalf("pushing 2nd record should result in lenght of 2") - } - - peekRec, ok := rq.Peek() - if !ok { - t.Fatalf("Peek() should return (_,true) when the list is not empty") - } - - if peekRec, ok := peekRec.(testRecord); ok { - if peekRec.id != 1 { - t.Fatalf("Peek() should return the 1st record pushed but returned %d", peekRec.id) - } - } else { - t.Fatalf("Peek() should return the original record type") - } - - if rq.Len() != 2 { - t.Fatalf("Peek() should not affect the list length") - } - - popRec, ok := rq.Pop() - if !ok { - t.Fatalf("Pop() should return (_,true) when the list is not empty") - } - - if rec, ok := popRec.(testRecord); ok { - if rec.id != 1 { - t.Fatalf("Pop() should return the first record pushed but returned %d", rec.id) - } - } else { - t.Fatalf("Pop() should return the original record") - } - - if rq.Len() != 1 { - t.Fatalf("Pop() should remove an element from the list") - } -} - -func TestRecordQueueNilReceiver(t *testing.T) { - var rq *RecordQueue - - if rq.Len() > 0 { - t.Fatalf("nil receiver should have zero lenght") - } - - _, ok := rq.Peek() - if ok { - t.Fatalf("Peek() on a nil receiver should return (_,false)") - } - - _, ok = rq.Pop() - if ok { - t.Fatalf("Pop() on a nil receiver should return (_,false)") - } - - rq.Push(nil) - if rq.Len() != 0 { - t.Fatalf("Push(nil) on a nil receiver should be a no-op") - } - - defer func() { - r := recover() - if r == nil { - t.Fatalf("expected a panic but got nothing") - } - - expMsg := "Push to nil RecordQueue" - if msg, ok := r.(string); !ok || msg != expMsg { - t.Fatalf(`want panic message %q but got "%v"`, expMsg, r) - } - }() - rq.Push(testRecord{}) -} - -type testRecord struct { - Record - - id int -} diff --git a/internal/report.go b/internal/report.go index a0018d6..753bc88 100644 --- a/internal/report.go +++ b/internal/report.go @@ -5,8 +5,9 @@ import ( "errors" "fmt" "io" - "math/big" "time" + + "github.com/shopspring/decimal" ) type RecordReader interface { @@ -20,7 +21,7 @@ type ReportWriter interface { } func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter) error { - buys := make(map[string]*RecordQueue) + buys := make(map[string]*FillerQueue) for { select { @@ -37,7 +38,7 @@ func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter) buyQueue, ok := buys[rec.Symbol()] if !ok { - buyQueue = new(RecordQueue) + buyQueue = new(FillerQueue) buys[rec.Symbol()] = buyQueue } @@ -49,42 +50,41 @@ func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter) } } -func processRecord(ctx context.Context, q *RecordQueue, rec Record, writer ReportWriter) error { +func processRecord(ctx context.Context, q *FillerQueue, rec Record, writer ReportWriter) error { switch rec.Side() { case SideBuy: - q.Push(rec) + q.Push(NewFiller(rec)) case SideSell: - unmatchedQty := new(big.Float).Copy(rec.Quantity()) - zero := new(big.Float) + unmatchedQty := rec.Quantity() - for unmatchedQty.Cmp(zero) > 0 { + for unmatchedQty.IsPositive() { buy, ok := q.Peek() if !ok { return ErrInsufficientBoughtVolume } - var matchedQty *big.Float - if buy.Quantity().Cmp(unmatchedQty) > 0 { - matchedQty = unmatchedQty - buy.Quantity().Sub(buy.Quantity(), unmatchedQty) - } else { - matchedQty = buy.Quantity() - q.Pop() + matchedQty, filled := buy.Fill(unmatchedQty) + + if filled { + _, ok := q.Pop() + if !ok { + return fmt.Errorf("pop empty filler queue") + } } - unmatchedQty.Sub(unmatchedQty, matchedQty) + unmatchedQty = unmatchedQty.Sub(matchedQty) - sellValue := new(big.Float).Mul(matchedQty, rec.Price()) - buyValue := new(big.Float).Mul(matchedQty, buy.Price()) + buyValue := matchedQty.Mul(buy.Price()) + sellValue := matchedQty.Mul(rec.Price()) err := writer.Write(ctx, ReportItem{ BuyValue: buyValue, BuyTimestamp: buy.Timestamp(), SellValue: sellValue, SellTimestamp: rec.Timestamp(), - Fees: new(big.Float).Add(buy.Fees(), rec.Fees()), - Taxes: new(big.Float).Add(buy.Taxes(), rec.Fees()), + Fees: buy.Fees().Add(rec.Fees()), + Taxes: buy.Taxes().Add(rec.Fees()), }) if err != nil { return fmt.Errorf("write report item: %w", err) @@ -99,16 +99,14 @@ func processRecord(ctx context.Context, q *RecordQueue, rec Record, writer Repor } type ReportItem struct { - BuyValue *big.Float + BuyValue decimal.Decimal BuyTimestamp time.Time - SellValue *big.Float + SellValue decimal.Decimal SellTimestamp time.Time - Fees *big.Float - Taxes *big.Float + Fees decimal.Decimal + Taxes decimal.Decimal } -func (ri ReportItem) RealisedPnL() *big.Float { - return new(big.Float).Sub(ri.SellValue, ri.BuyValue) +func (ri ReportItem) RealisedPnL() decimal.Decimal { + return ri.SellValue.Sub(ri.BuyValue) } - -var ErrInsufficientBoughtVolume = fmt.Errorf("insufficient bought volume") diff --git a/internal/report_test.go b/internal/report_test.go index 231b213..70b64c9 100644 --- a/internal/report_test.go +++ b/internal/report_test.go @@ -2,17 +2,18 @@ package internal_test import ( "context" + "fmt" "io" - "math/big" "testing" "time" "github.com/nmoniz/any2anexoj/internal" "github.com/nmoniz/any2anexoj/internal/mocks" + "github.com/shopspring/decimal" "go.uber.org/mock/gomock" ) -func TestReporter_Run(t *testing.T) { +func TestBuildReport(t *testing.T) { now := time.Now() ctrl := gomock.NewController(t) @@ -32,13 +33,13 @@ func TestReporter_Run(t *testing.T) { }).Times(3) writer := mocks.NewMockReportWriter(ctrl) - writer.EXPECT().Write(gomock.Any(), gomock.Eq(internal.ReportItem{ - BuyValue: new(big.Float).SetFloat64(200.0), + writer.EXPECT().Write(gomock.Any(), eqReportItem(internal.ReportItem{ + BuyValue: decimal.NewFromFloat(200.0), BuyTimestamp: now, - SellValue: new(big.Float).SetFloat64(250.0), + SellValue: decimal.NewFromFloat(250.0), SellTimestamp: now.Add(1), - Fees: new(big.Float), - Taxes: new(big.Float), + Fees: decimal.Decimal{}, + Taxes: decimal.Decimal{}, })).Times(1) gotErr := internal.BuildReport(t.Context(), reader, writer) @@ -49,12 +50,47 @@ func TestReporter_Run(t *testing.T) { func mockRecord(ctrl *gomock.Controller, price, quantity float64, side internal.Side, ts time.Time) *mocks.MockRecord { rec := mocks.NewMockRecord(ctrl) - rec.EXPECT().Price().Return(big.NewFloat(price)).AnyTimes() - rec.EXPECT().Quantity().Return(big.NewFloat(quantity)).AnyTimes() + rec.EXPECT().Price().Return(decimal.NewFromFloat(price)).AnyTimes() + rec.EXPECT().Quantity().Return(decimal.NewFromFloat(quantity)).AnyTimes() rec.EXPECT().Side().Return(side).AnyTimes() rec.EXPECT().Symbol().Return("TEST").AnyTimes() rec.EXPECT().Timestamp().Return(ts).AnyTimes() - rec.EXPECT().Fees().Return(new(big.Float)).AnyTimes() - rec.EXPECT().Taxes().Return(new(big.Float)).AnyTimes() + rec.EXPECT().Fees().Return(decimal.Decimal{}).AnyTimes() + rec.EXPECT().Taxes().Return(decimal.Decimal{}).AnyTimes() return rec } + +func eqReportItem(ri internal.ReportItem) ReportItemMatcher { + return ReportItemMatcher{ + ReportItem: ri, + } +} + +type ReportItemMatcher struct { + internal.ReportItem +} + +// Matches implements gomock.Matcher. +func (m ReportItemMatcher) Matches(x any) bool { + if x == nil { + return false + } + + switch other := x.(type) { + case internal.ReportItem: + return m.BuyValue.Equal(other.BuyValue) && + m.BuyTimestamp.Equal(other.BuyTimestamp) && + m.SellValue.Equal(other.SellValue) && + m.SellTimestamp.Equal(other.SellTimestamp) && + m.Fees.Equal(other.Fees) && + m.Taxes.Equal(other.Taxes) + default: + return false + } +} + +func (m ReportItemMatcher) String() string { + return fmt.Sprintf("is equivalent to %v", m.ReportItem) +} + +var _ gomock.Matcher = (*ReportItemMatcher)(nil) diff --git a/internal/stdout.go b/internal/stdout.go index 31000a6..0ced987 100644 --- a/internal/stdout.go +++ b/internal/stdout.go @@ -29,6 +29,6 @@ func NewReportLogger(w io.Writer) *ReportLogger { func (rl *ReportLogger) Write(_ context.Context, ri ReportItem) error { rl.counter++ - _, err := fmt.Fprintf(rl.writer, "%6d - realised %+f on %s\n", rl.counter, ri.RealisedPnL(), ri.SellTimestamp.Format(time.RFC3339)) + _, err := fmt.Fprintf(rl.writer, "%6d: realised %s on %s\n", rl.counter, ri.RealisedPnL().String(), ri.SellTimestamp.Format(time.RFC3339)) return err } diff --git a/internal/stdout_test.go b/internal/stdout_test.go index 5879732..fd181e9 100644 --- a/internal/stdout_test.go +++ b/internal/stdout_test.go @@ -3,11 +3,11 @@ package internal_test import ( "bytes" "fmt" - "math/big" "testing" "time" "github.com/nmoniz/any2anexoj/internal" + "github.com/shopspring/decimal" ) func TestReportLogger_Write(t *testing.T) { @@ -25,45 +25,45 @@ func TestReportLogger_Write(t *testing.T) { name: "single item positive", items: []internal.ReportItem{ { - BuyValue: new(big.Float).SetFloat64(100.0), - SellValue: new(big.Float).SetFloat64(200.0), + BuyValue: decimal.NewFromFloat(100.0), + SellValue: decimal.NewFromFloat(200.0), SellTimestamp: tNow, }, }, want: []string{ - fmt.Sprintf("%6d - realised +100.000000 on %s\n", 1, tNow.Format(time.RFC3339)), + fmt.Sprintf("%6d: realised 100 on %s\n", 1, tNow.Format(time.RFC3339)), }, }, { name: "single item negative", items: []internal.ReportItem{ { - BuyValue: new(big.Float).SetFloat64(200.0), - SellValue: new(big.Float).SetFloat64(150.0), + BuyValue: decimal.NewFromFloat(200.0), + SellValue: decimal.NewFromFloat(150.0), SellTimestamp: tNow, }, }, want: []string{ - fmt.Sprintf("%6d - realised -50.000000 on %s\n", 1, tNow.Format(time.RFC3339)), + fmt.Sprintf("%6d: realised -50 on %s\n", 1, tNow.Format(time.RFC3339)), }, }, { name: "multiple items", items: []internal.ReportItem{ { - BuyValue: new(big.Float).SetFloat64(100.0), - SellValue: new(big.Float).SetFloat64(200.0), + BuyValue: decimal.NewFromFloat(100.0), + SellValue: decimal.NewFromFloat(200.0), SellTimestamp: tNow, }, { - BuyValue: new(big.Float).SetFloat64(200.0), - SellValue: new(big.Float).SetFloat64(150.0), + BuyValue: decimal.NewFromFloat(200.0), + SellValue: decimal.NewFromFloat(150.0), SellTimestamp: tNow.Add(1), }, }, want: []string{ - fmt.Sprintf("%6d - realised +100.000000 on %s\n", 1, tNow.Format(time.RFC3339)), - fmt.Sprintf("%6d - realised -50.000000 on %s\n", 2, tNow.Add(1).Format(time.RFC3339)), + fmt.Sprintf("%6d: realised 100 on %s\n", 1, tNow.Format(time.RFC3339)), + fmt.Sprintf("%6d: realised -50 on %s\n", 2, tNow.Add(1).Format(time.RFC3339)), }, }, } diff --git a/internal/trading212/record.go b/internal/trading212/record.go index 541da5d..cae8db5 100644 --- a/internal/trading212/record.go +++ b/internal/trading212/record.go @@ -5,21 +5,21 @@ import ( "encoding/csv" "fmt" "io" - "math/big" "strings" "time" "github.com/nmoniz/any2anexoj/internal" + "github.com/shopspring/decimal" ) type Record struct { symbol string side internal.Side - quantity *big.Float - price *big.Float + quantity decimal.Decimal + price decimal.Decimal timestamp time.Time - fees *big.Float - taxes *big.Float + fees decimal.Decimal + taxes decimal.Decimal } func (r Record) Symbol() string { @@ -30,11 +30,11 @@ func (r Record) Side() internal.Side { return r.side } -func (r Record) Quantity() *big.Float { +func (r Record) Quantity() decimal.Decimal { return r.quantity } -func (r Record) Price() *big.Float { +func (r Record) Price() decimal.Decimal { return r.price } @@ -42,11 +42,11 @@ func (r Record) Timestamp() time.Time { return r.timestamp } -func (r Record) Fees() *big.Float { +func (r Record) Fees() decimal.Decimal { return r.fees } -func (r Record) Taxes() *big.Float { +func (r Record) Taxes() decimal.Decimal { return r.taxes } @@ -123,24 +123,23 @@ func (rr RecordReader) ReadRecord(_ context.Context) (internal.Record, error) { price: price, timestamp: ts, fees: convertionFee, - taxes: new(big.Float).Add(stampDutyTax, frenchTxTax), + taxes: stampDutyTax.Add(frenchTxTax), }, nil } } // parseFloat attempts to parse a string using a standard precision and rounding mode. // Using this function helps avoid issues around converting values due to sligh parameter changes. -func parseDecimal(s string) (*big.Float, error) { - f, _, err := big.ParseFloat(s, 10, 128, big.ToZero) - return f, err +func parseDecimal(s string) (decimal.Decimal, error) { + return decimal.NewFromString(s) } // parseOptinalDecimal behaves the same as parseDecimal but returns 0 when len(s) is 0 instead of // error. // Using this function helps avoid issues around converting values due to sligh parameter changes. -func parseOptinalDecimal(s string) (*big.Float, error) { +func parseOptinalDecimal(s string) (decimal.Decimal, error) { if len(s) == 0 { - return new(big.Float), nil + return decimal.Decimal{}, nil } return parseDecimal(s) diff --git a/internal/trading212/record_test.go b/internal/trading212/record_test.go index 10d71b9..b464f1a 100644 --- a/internal/trading212/record_test.go +++ b/internal/trading212/record_test.go @@ -3,11 +3,11 @@ package trading212 import ( "bytes" "io" - "math/big" "testing" "time" "github.com/nmoniz/any2anexoj/internal" + "github.com/shopspring/decimal" ) func TestRecordReader_ReadRecord(t *testing.T) { @@ -136,7 +136,7 @@ func TestRecordReader_ReadRecord(t *testing.T) { } } -func ShouldParseDecimal(t testing.TB, sf string) *big.Float { +func ShouldParseDecimal(t testing.TB, sf string) decimal.Decimal { t.Helper() bf, err := parseDecimal(sf)