Skip to content

Commit

Permalink
support mindspore
Browse files Browse the repository at this point in the history
  • Loading branch information
lvyufeng committed May 31, 2024
1 parent 6181e1e commit fb3676a
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 0 deletions.
74 changes: 74 additions & 0 deletions einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,3 +713,77 @@ def is_float_type(self, x):

def einsum(self, pattern, *x):
    """Evaluate ``pattern`` over the operands ``x`` using ``tinygrad.Tensor.einsum``."""
    operands = x
    return self.tinygrad.Tensor.einsum(pattern, *operands)

class MindSporeBackend(AbstractBackend):
    """Backend that maps the einops primitive operations onto MindSpore tensors."""

    framework_name = "mindspore"

    def __init__(self):
        # Imported here rather than at module level, so merely importing this
        # file does not import MindSpore.
        import mindspore
        from mindspore import ops

        self.ms_ops = ops
        self.ms = mindspore

    def is_appropriate_type(self, tensor):
        """Return True when ``tensor`` is an instance of ``mindspore.Tensor``."""
        return isinstance(tensor, self.ms.Tensor)

    def from_numpy(self, x):
        """Build a MindSpore tensor from ``x`` with ``Tensor.from_numpy``."""
        return self.ms.Tensor.from_numpy(x)

    def to_numpy(self, x):
        """Convert the tensor ``x`` with its ``asnumpy`` method."""
        return x.asnumpy()

    def arange(self, start, stop):
        """Return ``ops.arange(start, stop)`` with dtype int64."""
        return self.ms_ops.arange(start, stop, dtype=self.ms.int64)

    def reduce(self, x, operation, reduced_axes):
        """Apply the named reduction to ``x`` over each axis in ``reduced_axes``.

        Raises NotImplementedError when ``operation`` is not a known reduction
        (only checked if there is at least one axis to reduce).
        """
        known_reductions = ("min", "max", "sum", "mean", "prod", "any", "all")
        # One axis per call, highest index first -- presumably so that the
        # remaining (lower) axis indices are unaffected by each reduction.
        for axis in sorted(reduced_axes, reverse=True):
            if operation not in known_reductions:
                raise NotImplementedError("Unknown reduction ", operation)
            x = getattr(x, operation)(axis=axis)
        return x

    def transpose(self, x, axes):
        """Permute the axes of ``x`` according to ``axes``."""
        return x.transpose(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        """Stack ``tensors`` with ``ops.stack`` using its default axis."""
        return self.ms_ops.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        """Insert an axis at every position in ``pos2len`` and broadcast it to its length.

        Positions not listed in ``pos2len`` are passed to ``broadcast_to`` as -1.
        """
        target_shape = [-1] * n_axes
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            target_shape[position] = length
        return x.broadcast_to(tuple(target_shape))

    def reshape(self, x, shape):
        """Reshape ``x`` to ``shape`` and call ``stub_sync()`` on the result."""
        reshaped = x.reshape(shape)
        return reshaped.stub_sync()

    def tile(self, x, repeats):
        """Tile ``x`` according to ``repeats``."""
        return x.tile(repeats)

    def concat(self, tensors, axis: int):
        """Concatenate ``tensors`` along ``axis``."""
        return self.ms_ops.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        """Insert a new axis into ``x`` at ``new_position`` via ``ops.unsqueeze``."""
        return self.ms_ops.unsqueeze(x, new_position)

    def is_float_type(self, x):
        """Return True when ``x.dtype`` is float16, float32 or float64."""
        float_dtypes = [self.ms.float16, self.ms.float32, self.ms.float64]
        return x.dtype in float_dtypes

    def layers(self):
        """Return the ``einops.layers.mindspore`` module."""
        from .layers import mindspore as mindspore_layers

        return mindspore_layers

    def einsum(self, pattern, *x):
        """Run ``ops.einsum``; a result of shape ``(1,)`` is reshaped to ``()``."""
        result = self.ms_ops.einsum(pattern, *x)
        if result.shape == (1,):
            return result.reshape(())
        return result
57 changes: 57 additions & 0 deletions einops/layers/mindspore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Optional, Dict, cast

import mindspore
from mindspore import nn, ops
from mindspore.common.initializer import initializer, Uniform

from . import RearrangeMixin, ReduceMixin
from ._einmix import _EinmixMixin

__author__ = "Yufeng Lyu"


class Rearrange(RearrangeMixin, nn.Cell):
    """MindSpore cell whose ``construct`` forwards its input to ``_apply_recipe``."""

    def construct(self, input):
        rearranged = self._apply_recipe(input)
        return rearranged


class Reduce(ReduceMixin, nn.Cell):
    """MindSpore cell whose ``construct`` forwards its input to ``_apply_recipe``."""

    def construct(self, input):
        reduced = self._apply_recipe(input)
        return reduced


class EinMix(_EinmixMixin, nn.Cell):
    """MindSpore cell that multiplies its input by a weight via einsum, with an optional bias."""

    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
        """Create the ``weight`` parameter and, if ``bias_shape`` is given, ``bias``.

        Each parameter is drawn from ``Uniform`` scaled by its own bound:
        ``weight_bound`` for the weight and ``bias_bound`` for the bias.
        ``self.bias`` is set to None when ``bias_shape`` is None.
        """
        self.weight = mindspore.Parameter(
            initializer(Uniform(weight_bound), weight_shape), requires_grad=True
        )
        if bias_shape is not None:
            # The bias uses bias_bound; initialising it with weight_bound
            # would silently ignore the bound passed in for the bias.
            self.bias = mindspore.Parameter(
                initializer(Uniform(bias_bound), bias_shape), requires_grad=True
            )
        else:
            self.bias = None

    def _create_rearrange_layers(
        self,
        pre_reshape_pattern: Optional[str],
        pre_reshape_lengths: Optional[Dict],
        post_reshape_pattern: Optional[str],
        post_reshape_lengths: Optional[Dict],
    ):
        """Create the optional ``Rearrange`` layers applied before and after the einsum.

        A layer is left as None when its pattern is None.
        """
        self.pre_rearrange = None
        if pre_reshape_pattern is not None:
            self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths))

        self.post_rearrange = None
        if post_reshape_pattern is not None:
            self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths))

    def construct(self, input):
        """Apply pre-rearrange (if any), the einsum with ``weight``, bias (if any), then post-rearrange (if any)."""
        if self.pre_rearrange is not None:
            input = self.pre_rearrange(input)
        result = ops.einsum(self.einsum_pattern, input, self.weight)
        if self.bias is not None:
            result += self.bias
        if self.post_rearrange is not None:
            result = self.post_rearrange(result)
        return result
1 change: 1 addition & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def main():
# "paddle": ["paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html"],
"paddle": ["paddlepaddle"],
"oneflow": ["oneflow==0.9.0"],
"mindspore": ["mindspore"]
}

usage = f"""
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def collect_test_backends(symbolic=False, layers=False) -> List[_backends.Abstra
_backends.OneFlowBackend,
_backends.PaddleBackend,
_backends.CupyBackend,
_backends.MindSporeBackend,
]
else:
backend_types = [
Expand Down
1 change: 1 addition & 0 deletions tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def test_layer():
"cupy",
"tensorflow.keras",
"paddle",
"mindspore"
]


Expand Down
44 changes: 44 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,47 @@ def eval_at_point(params):
# check serialization
fbytes = flax.serialization.to_bytes(params)
_loaded = flax.serialization.from_bytes(params, fbytes)

def create_mindspore_model(use_reduce=False):
    """Build a small convolutional ``SequentialCell`` that mixes MindSpore and einops layers.

    The first pooling stage is an einops ``Reduce`` when ``use_reduce`` is true
    and a ``MaxPool2d`` otherwise. The test is skipped when the mindspore
    backend is not being tested.
    """
    if not is_backend_tested("mindspore"):
        pytest.skip()

    from mindspore.nn import SequentialCell, Conv2d, MaxPool2d, Dense, ReLU
    from einops.layers.mindspore import Rearrange, Reduce, EinMix

    if use_reduce:
        first_pool = Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2)
    else:
        # Stride is given explicitly: the pooling must halve the spatial size,
        # the same as the Reduce variant above.
        first_pool = MaxPool2d(kernel_size=2, stride=2)

    # pad_mode="valid" keeps the convolutions unpadded, so a 32x32 input goes
    # 32 -> 28 -> 14 -> 10 -> 5 and matches the Dense(16 * 5 * 5, ...) below.
    return SequentialCell(
        Conv2d(3, 6, kernel_size=(5, 5), pad_mode="valid"),
        first_pool,
        Conv2d(6, 16, kernel_size=(5, 5), pad_mode="valid"),
        Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
        Rearrange("b c h w -> b (c h w)"),
        Dense(16 * 5 * 5, 120),
        ReLU(),
        Dense(120, 84),
        ReLU(),
        EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84),
        EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84),
        Dense(84, 10),
    )


def test_mindspore_layer():
    """Check that the MindSpore einops layers run and survive a checkpoint round trip.

    Two independently initialised models must disagree on the same input;
    after saving the first model's checkpoint and loading it into the second,
    they must agree.
    """
    if not is_backend_tested("mindspore"):
        pytest.skip()

    # Imported only after the backend check above has passed.
    import numpy as np
    import mindspore
    from mindspore import ops

    # Build the MindSpore models (not the torch ones) so that the layers in
    # einops.layers.mindspore are what is actually exercised.
    model1 = create_mindspore_model(use_reduce=True)
    model2 = create_mindspore_model(use_reduce=False)
    inputs = ops.randn([10, 3, 32, 32])
    # random models have different predictions
    assert not np.allclose(model1(inputs).asnumpy(), model2(inputs).asnumpy())

    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = f"{tmp_dir}/model.ckpt"
        mindspore.save_checkpoint(model1, filename)
        mindspore.load_checkpoint(filename, model2)
    assert np.allclose(model1(inputs).asnumpy(), model2(inputs).asnumpy())
2 changes: 2 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ def test_reduction_stress_imperatives():
dtype = "float64"
coincide = numpy.allclose
max_dim = 11
if "mindspore" in backend.framework_name:
max_dim = 7
if "oneflow" in backend.framework_name:
max_dim = 7
if "paddle" in backend.framework_name:
Expand Down

0 comments on commit fb3676a

Please sign in to comment.