Skip to content

Commit

Permalink
Merge pull request #12 from Algebra8/itertools-count-float-input
Browse files Browse the repository at this point in the history
Itertools count float input
  • Loading branch information
tdakkota authored Apr 26, 2021
2 parents d3af75a + b847cb7 commit e381dd5
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 51 deletions.
149 changes: 99 additions & 50 deletions itertools.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,98 @@ import (
"go.starlark.net/starlark"
)

type countObject struct {
cnt int
step int
frozen bool
value starlark.Value
// float or int type to allow mixed inputs.
type floatOrInt struct {
value starlark.Value
}

// Unpacker for floatOrInt.
func (p *floatOrInt) Unpack(v starlark.Value) error {
switch v := v.(type) {
case starlark.Int:
p.value = v
return nil
case starlark.Float:
p.value = v
return nil
}
return fmt.Errorf("got %s, want float or int", v.Type())
}

func (f *floatOrInt) add(n floatOrInt) error {
switch _f := f.value.(type) {
case starlark.Int:
switch _n := n.value.(type) {
// int + int
case starlark.Int:
f.value = _f.Add(_n)
return nil
// int + float
case starlark.Float:
_n += _f.Float()
f.value = _n
return nil
}
case starlark.Float:
switch _n := n.value.(type) {
// float + int
case starlark.Int:
_f += _n.Float()
f.value = _f
return nil
// float + float
case starlark.Float:
_f += _n
f.value = _f
return nil
}
}

return fmt.Errorf("error with addition: types are not int, float combos")
}

func (f *floatOrInt) String() string {
return f.value.String()
}

func newCountObject(cnt int, stepValue int) *countObject {
return &countObject{cnt: cnt, step: stepValue, value: starlark.MakeInt(cnt)}
// Iterator implementation for countObject.
type countIter struct {
co *countObject
}

func (co *countObject) String() string {
func (c *countIter) Next(p *starlark.Value) bool {
if c.co.frozen {
return false
}

*p = c.co.cnt.value

if e := c.co.cnt.add(c.co.step); e != nil {
return false
}

return true
}

func (c *countIter) Done() {}

// countObject implementation as a starlark.Value.
type countObject struct {
cnt, step floatOrInt
frozen bool
}

func (co countObject) String() string {
// As with the cpython implementation, we don't display
// step when it is an integer equal to 1.
if co.step == 1 {
return fmt.Sprintf("count(%v)", co.cnt)
// step when it is an integer equal to 1 (default step value).
step, ok := co.step.value.(starlark.Int)
if ok {
if x, ok := step.Int64(); ok && x == 1 {
return fmt.Sprintf("count(%v)", co.cnt.String())
}
}
return fmt.Sprintf("count(%v, %v)", co.cnt, co.step)

return fmt.Sprintf("count(%v, %v)", co.cnt.String(), co.step.String())
}

func (co *countObject) Type() string {
Expand All @@ -33,7 +107,6 @@ func (co *countObject) Type() string {
func (co *countObject) Freeze() {
if !co.frozen {
co.frozen = true
co.value.Freeze()
}
}

Expand All @@ -50,59 +123,35 @@ func (co *countObject) Iterate() starlark.Iterator {
return &countIter{co: co}
}

type countIter struct {
co *countObject
}

func (c *countIter) Next(p *starlark.Value) bool {
if c.co.frozen {
return false
}
*p = starlark.MakeInt(c.co.cnt)
c.co.cnt += c.co.step
return true
}

func (c *countIter) Done() {}

func count_(
thread *starlark.Thread,
_ *starlark.Builtin,
args starlark.Tuple,
kwargs []starlark.Tuple,
) (starlark.Value, error) {
var (
start int
step int
defaultStart = starlark.MakeInt(0)
defaultStep = starlark.MakeInt(1)
start floatOrInt
step floatOrInt
)

if err := starlark.UnpackPositionalArgs(
"count", args, kwargs, 0, &start, &step,
); err != nil {
return nil, fmt.Errorf(
"Got %v but expected NoneType or valid integer values for "+
"start and step, such as (0, 1).", args.String(),
"Got %v but expected no args, or one or two valid numbers",
args.String(),
)
}

const (
defaultStart = 0
defaultStep = 1
)
// The rules for populating the count object based on the number
// of args passed is as follows:
// 0 args -> default values for start and step
// 1 args -> arg defines start, default for step
// 2 args -> both start and step are defined by args
var co_ *countObject
switch nargs := len(args); {
case nargs == 0:
co_ = newCountObject(defaultStart, defaultStep)
case nargs == 1:
co_ = newCountObject(start, defaultStep)
default: // nargs == 2
co_ = newCountObject(start, step)
// Check if start or step require default values.
if start.value == nil {
start.value = defaultStart
}
if step.value == nil {
step.value = defaultStep
}

return co_, nil
return &countObject{cnt: start, step: step}, nil
}
94 changes: 93 additions & 1 deletion testdata/itertools.star
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,107 @@ def test_count():
assert.eq(str(c2), "count(11, 3)")
assert.eq(next(c2), 11)

# Negative args.
c3 = count(-5, -10)
assert.eq(str(c3), "count(-5, -10)")
assert.eq(next(c3), -5)
assert.eq(str(c3), "count(-15, -10)")
assert.eq(next(c3), -15)

c4 = count(5, -5)
assert.eq(str(c4), "count(5, -5)")
assert.eq(next(c4), 5)
assert.eq(str(c4), "count(0, -5)")
assert.eq(next(c4), 0)
assert.eq(str(c4), "count(-5, -5)")
assert.eq(next(c4), -5)

# Int start, float step.
c5 = count(0, 0.1)
assert.eq(str(c5), "count(0, 0.1)")
assert.eq(next(c5), 0)
assert.eq(str(c5), "count(0.1, 0.1)")
assert.eq(next(c5), 0.1)
assert.eq(str(c5), "count(0.2, 0.1)")
assert.eq(next(c5), 0.2)

# Float start, int step — this should be handled same as above
# but check to be exhaustive.
c6 = count(0.5, 5)
assert.eq(str(c6), "count(0.5, 5)")
assert.eq(next(c6), 0.5)
assert.eq(str(c6), "count(5.5, 5)")
assert.eq(next(c6), 5.5)
assert.eq(str(c6), "count(10.5, 5)")
assert.eq(next(c6), 10.5)

# This test may seem similar to c5 but is different because
# here step > 1. In the case that 0 < step < 1, fmt.Sprintf,
# which is used in String(), will display it as a float but
# may display it as an int if the proper flags aren't used.
c7 = count(5.0, 0.5)
assert.eq(str(c7), "count(5.0, 0.5)")
assert.eq(next(c7), 5.0)
assert.eq(str(c7), "count(5.5, 0.5)")
assert.eq(next(c7), 5.5)
assert.eq(str(c7), "count(6.0, 0.5)")
assert.eq(next(c7), 6.0)

# NaNs
c8 = count(0, float('nan'))
assert.eq(str(c8), "count(0, %s)" % (float('nan')))
assert.eq(next(c8), 0)
assert.eq(str(c8), "count(%s, %s)" % (float('nan'), float('nan')))
assert.eq(next(c8), float('nan'))
assert.eq(str(c8), "count(%s, %s)" % (float('nan'), float('nan')))
assert.eq(next(c8), float('nan'))

c9 = count(0, float("+inf"))
assert.eq(str(c9), "count(0, %s)" % (float("+inf")))
assert.eq(next(c9), 0)
assert.eq(str(c9), "count(%s, %s)" % (float("+inf"), float("+inf")))
assert.eq(next(c9), float("+inf"))
assert.eq(str(c9), "count(%s, %s)" % (float("+inf"), float("+inf")))
assert.eq(next(c9), float("+inf"))

c10 = count(0, float("-inf"))
assert.eq(str(c10), "count(0, %s)" % (float("-inf")))
assert.eq(next(c10), 0)
assert.eq(str(c10), "count(%s, %s)" % (float("-inf"), float("-inf")))
assert.eq(next(c10), float("-inf"))
assert.eq(str(c10), "count(%s, %s)" % (float("-inf"), float("-inf")))
assert.eq(next(c10), float("-inf"))

c11 = count(float("nan"), 2)
assert.eq(str(c11), "count(%s, 2)" % (float('nan')))
assert.eq(next(c11), float('nan'))
assert.eq(str(c11), "count(%s, 2)" % (float('nan')))
assert.eq(next(c11), float('nan'))

c12 = count(float("+inf"), 2)
assert.eq(str(c12), "count(%s, 2)" % (float('+inf')))
assert.eq(next(c12), float('+inf'))
assert.eq(str(c12), "count(%s, 2)" % (float('+inf')))
assert.eq(next(c12), float('+inf'))

c13 = count(float("-inf"), 2)
assert.eq(str(c13), "count(%s, 2)" % (float('-inf')))
assert.eq(next(c13), float('-inf'))
assert.eq(str(c13), "count(%s, 2)" % (float('-inf')))
assert.eq(next(c13), float('-inf'))

# Fails
z = ("a", "b")
# Non-numeric arg fails.
assert.fails(
lambda: count("a", "b"),
# fails uses match under the hood, which will use
# regexp.MatchString, so need to use raw pattern
# that MatchString would accept.
r'Got \(\"a\", \"b\"\)',
)

# Too many arg fails — should be handled by UnpackArgs but
# check to be exhaustive.
assert.fails(
lambda: count(1, 2, 3),
r'Got \(1, 2, 3\)'
Expand Down

0 comments on commit e381dd5

Please sign in to comment.