Skip to content

Commit

Permalink
Update GitHub action
Browse files Browse the repository at this point in the history
Also fix the runs for Tensorflow backend
  • Loading branch information
lettercode committed Jun 15, 2024
1 parent c5cab47 commit 1335cb8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
48 changes: 40 additions & 8 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ permissions:
contents: read

jobs:
build:

build_pytorch_backend:
runs-on: ubuntu-latest

container:
image: pytorch/pytorch:2.3.1-cuda11.8-cudnn8-runtime
env:
KERAS_BACKEND: torch
volumes:
- my_docker_volume:/volume_mount

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: "3.10"

- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -35,6 +38,35 @@ jobs:
echo "PYTHONPATH=." >> $GITHUB_ENV
- name: Test with pytest
run: |
pytest ncps/tests/test_tf.py
pytest ncps/tests/test_torch.py
pytest ncps/tests/test_keras.py
build_tensorflow_backend:
runs-on: ubuntu-latest

container:
image: tensorflow/tensorflow:2.16.1
env:
KERAS_BACKEND: tensorflow
volumes:
- my_docker_volume:/volume_mount

steps:
- uses: actions/checkout@v3

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: set pythonpath
run: |
echo "PYTHONPATH=." >> $GITHUB_ENV
- name: Test with pytest
run: |
pytest ncps/tests/test_keras.py
21 changes: 13 additions & 8 deletions ncps/tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Run on CPU
os.environ["KERAS_BACKEND"] = "torch"
# os.environ["KERAS_BACKEND"] = "torch"
# os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
Expand Down Expand Up @@ -230,6 +230,7 @@ def test_ltc_rnn():
model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mean_squared_error")
model.fit(x=data_x, y=data_y, batch_size=1, epochs=3)


def test_ncps():
input_size = 8

Expand All @@ -238,8 +239,9 @@ def test_ncps():
data = keras.random.normal([3, input_size])
hx = keras.ops.zeros([3, wiring.units])
output, hx = ltc_cell(data, hx)
assert output.size() == (3, 4)
assert hx[0].size() == (3, wiring.units)
assert output.shape == (3, 4)
assert hx[0].shape == (3, wiring.units)


def test_ncp_sizes():
wiring = ncps.wirings.NCP(10, 10, 8, 6, 6, 4, 6)
Expand All @@ -248,25 +250,28 @@ def test_ncp_sizes():
output = rnn(data)
assert wiring.synapse_count > 0
assert wiring.sensory_synapse_count > 0
assert output.size() == (5, 8)
assert output.shape == (5, 8)


def test_auto_ncp():
wiring = ncps.wirings.AutoNCP(16, 4)
rnn = LTC(wiring)
data = keras.random.normal([5, 3, 8])
output = rnn(data)
assert output.size() == (5, 4)
assert output.shape == (5, 4)


def test_ncp_cfc():
wiring = ncps.wirings.NCP(10, 10, 8, 6, 6, 4, 6)
rnn = CfC(wiring)
data = keras.random.normal([5, 3, 8])
output = rnn(data)
assert output.size() == (5, 8)
assert output.shape == (5, 8)


def test_auto_ncp_cfc():
wiring = ncps.wirings.AutoNCP(28, 10)
rnn = CfC(wiring)
data = keras.random.normal([5, 3, 8])
output = rnn(data)
assert output.size() == (5, 10)
assert output.shape == (5, 10)

0 comments on commit 1335cb8

Please sign in to comment.