Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Equinox #25

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deltapv/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
'''
DEPRECATED
'''


import dataclasses
import jax

Expand Down
18 changes: 18 additions & 0 deletions deltapv/equinox_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
def eqx_dataclass(data_clz):
# Uses equinox to create a new class with the same fields as the original
# and getter methods for pytrees

class clz(data_clz):

items = data_clz.__dataclass_fields__.items()

def iterate_clz(self, x):
# iterates the class and returns a tuple of every field
return tuple(getattr(x, name) for name in self.items)

def clz_from_iterable(self, data):
# creates a new class from a tuple of fields
kwargs = dict(zip(self.items, data))
return data_clz(**kwargs)

return clz
27 changes: 14 additions & 13 deletions deltapv/objects.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from deltapv import dataclasses, util
from deltapv import equinox_objects, util
from jax import numpy as jnp
from typing import Union
import equinox as eqx

Array = util.Array
f64 = util.f64
i64 = util.i64


@dataclasses.dataclass
class PVDesign:
@equinox_objects.eqx_dataclass
class PVDesign(eqx.Module):

grid: Array
eps: Array
Expand All @@ -35,8 +36,8 @@ class PVDesign:
PhiML: f64


@dataclasses.dataclass
class PVCell:
@equinox_objects.eqx_dataclass
class PVCell(eqx.Module):

dgrid: Array
eps: Array
Expand Down Expand Up @@ -70,15 +71,15 @@ def zero_cell(n: i64) -> PVCell:
return zc


@dataclasses.dataclass
class LightSource:
@equinox_objects.eqx_dataclass
class LightSource(eqx.Module):

Lambda: Array = jnp.ones(1)
P_in: Array = jnp.zeros(1)


@dataclasses.dataclass
class Material:
@equinox_objects.eqx_dataclass
class Material(eqx.Module):
eps: f64 = f64(1)
Chi: f64 = f64(1)
Eg: f64 = f64(1)
Expand All @@ -99,8 +100,8 @@ def __iter__(self):
return self.__dict__.items().__iter__()


@dataclasses.dataclass
class Potentials:
@equinox_objects.eqx_dataclass
class Potentials(eqx.Module):
phi: Array
phi_n: Array
phi_p: Array
Expand All @@ -112,8 +113,8 @@ def zero_pot(n: i64) -> Potentials:
return zp


@dataclasses.dataclass
class Boundary:
@equinox_objects.eqx_dataclass
class Boundary(eqx.Module):
phi0: f64
phiL: f64
neq0: f64
Expand Down
2 changes: 1 addition & 1 deletion deltapv/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from jax import numpy as jnp, jit, custom_jvp, grad
from jax.experimental import optimizers
from jax.example_libraries import optimizers
import jax
import numpy as np
from deltapv import spline, simulator
Expand Down
227 changes: 227 additions & 0 deletions neuralnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neural_network import MLPRegressor as nn
import deltapv as dpv
from jax import numpy as jnp
from jax import grad, value_and_grad
import numpy as np
import jax

'''
This model grabs data and predicts the material properties then feeds it into
the dpv solver to get the predicted result. Then, performs back propagation to
find the changes in the material properties that get it closer to expected
result and from that finds the changes in the model to get it closer to changed values of
material properties.

'''


#####################
# Solving dpv #
#####################
def create_design(params):
'''
Given parameters that form a solar cell, create a deltapv object that
represents the design
'''
L = 3e-4
J = 5e-6
Chi=params[0]
Eg=params[1]
eps=params[2],
Nc=params[3],
Nv=params[4],
mn=params[5],
mp=params[6],
Et=params[7],
tn=params[8],
tp=params[9],
A=params[10]


material = dpv.create_material(Eg=Eg,
Chi=Chi,
eps=eps,
Nc=Nc,
Nv=Nv,
mn=mn,
mp=mp,
Et=Et,
tn=tn,
tp=tp,
A=A)


des = dpv.make_design(n_points=500,
Ls=[J, L - J],
mats=[material, material],
Ns=[1e17, -1e15],
Snl=1e7,
Snr=0,
Spl=0,
Spr=1e7)
return des



def f(params):
'''
Given a set of params that construct a solar cell, returns the efficiency
of that solar cell
'''
des = create_design(params)
results = dpv.simulate(des, verbose=False)
eff = results["eff"] * 100
return eff



# df is a tuple of f and the gradient of f, two functions
df = value_and_grad(f)



def f_np(x):
'''
f_np(x) takes x, a set of parameters representing material properties,
and returns the efficiency and gradient of efficiency with respect
to each property
'''
y, dy = df(x)
result = float(y), np.array(dy)
return result


##########################
# Data Formatting #
##########################
mats = []




#####################
# Neural Net #
#####################
def InitializeWeights(layer_sizes, seed):
weights = []

for i, units in enumerate(layer_sizes):
if i==0:
w = jax.random.uniform(key=seed, shape=(units, features), minval=-1.0, maxval=1.0, dtype=jnp.float32)
else:
w = jax.random.uniform(key=seed, shape=(units, layer_sizes[i-1]), minval=-1.0, maxval=1.0,
dtype=jnp.float32)

b = jax.random.uniform(key=seed, minval=-1.0, maxval=1.0, shape=(units,), dtype=jnp.float32)

weights.append([w,b])

return weights


def Relu(X):
'''
Activation function that returns max(0,x) for all x in X
'''
return jnp.maximum(X, jnp.zeros_like(X))



def LinearLayer(layer_weights, input_data, activation=lambda x: x):
'''
Computes one layer of the neural network, to be used in ForwardPass
'''
w, b = layer_weights
out = jnp.dot(input_data, w.T) + b
return activation(out)



def ForwardPass(weights, input_data):
'''
Passes the data through the neural network and returns the output params
'''
layer_out = input_data

for i in range(len(weights[:-1])):
layer_out = LinearLayer(weights[i], layer_out, Relu)

preds = LinearLayer(weights[-1], layer_out)

return preds.squeeze()



def MeanSquaredErrorLoss(weights, input_data, actual):
preds = ForwardPass(weights, input_data)
return jnp.power(actual - preds, 2).mean()



def CalculateGradients(weights, input_data, actual):
Grad_MSELoss = grad(MeanSquaredErrorLoss)
gradients = Grad_MSELoss(weights, input_data, actual)
return gradients



def TrainModel(weights, X, y, learning_rate, epochs):
'''
Trains on one data point per epoch, chosen randomly from the training set
'''
for i in range(epochs):
rand_mat = np.random.randint(0,len(mats))

# predicts the properties for every material
pred_props = ForwardPass(X[rand_mat])

# turns these predictions into predictions of efficiency
pred_eff, grad_pred_eff = f_np(create_design(pred_props))

# based on the predictions of efficiency, produces a list of expected material property values
alpha = 0.05
exp = pred_props + alpha * (y - pred_eff) * grad_pred_eff

# Finally, update the neural net with the expected output
loss = MeanSquaredErrorLoss(weights, X, exp)
gradients = CalculateGradients(weights, X)

## Update Weights
for j in range(len(weights)):
weights[j][0] -= learning_rate * gradients[j][0] ## Update Weights
weights[j][1] -= learning_rate * gradients[j][1] ## Update Biases

if i%5 ==0: ## Print MSE every 5 epochs
print("MSE : {:.2f}".format(loss))



if __name__ == '__main__':
pass





# REFERENCE:
# https://coderzcolumn.com/tutorials/artificial-intelligence/guide-to-create-simple-neural-networks-using-jax

'''
def find_params_for_nn():

rf_grid_params = {'n_estimators': [250,400,550], 'learning_rate': [0.05, 0.15, 0.25], 'max_depth': [1,2,3,4,5]}
grid = GridSearchCV(estimator = nn(), param_grid = rf_grid_params, refit = True, verbose = 2, cv = 5)
grid.fit(X_train, y_train)

mlp = nn(hidden_layer_sizes=(8,8,8), activation='relu', solver='adam', max_iter=200)
mlp.fit(X_train,y_train)

predict_train = mlp.predict(X_train)
predict_test = mlp.predict(X_test)


'''



6 changes: 3 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import deltapv as dpv
from jax import numpy as jnp
#from jax import numpy as jnp
import numpy as np
from scipy.optimize import minimize
from optimize import psc
Expand Down Expand Up @@ -50,8 +50,8 @@ def test_iv(self):
0.008610345709349041, -0.018267911703588706
]

self.assertTrue(jnp.allclose(v, v_correct), "Voltages do not match!")
self.assertTrue(jnp.allclose(j, j_correct), "Currents do not match!")
self.assertTrue(np.allclose(v, v_correct), "Voltages do not match!")
self.assertTrue(np.allclose(j, j_correct), "Currents do not match!")

def test_psc(self):
bounds = [(1, 5), (1, 5), (1, 20), (17, 20), (17, 20), (0, 3), (0, 3),
Expand Down