Skip to content

Commit

Permalink
Working nn driver
Browse files Browse the repository at this point in the history
  • Loading branch information
jzbontar committed Jan 22, 2016
1 parent d3ff28d commit f84e113
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/drivers/jz/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@ PKGSUBDIRS = ${SHIPSUBDIRS}
src-robots-jz_PKGFILES = $(shell find * -maxdepth 0 -type f -print)
src-robots-jz_PKGDIR = ${PACKAGE}-${VERSION}/$(subst ${TORCS_BASE},,$(shell pwd))

LIBS = -L/home/jure/torch/install/lib -lluajit -lluaT -lTH
COMPILFLAGS = -I/home/jure/torch/install/include/ -I/home/jure/torch/install/include/TH/

include ${MAKE_DEFAULT}
45 changes: 40 additions & 5 deletions src/drivers/jz/jz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <err.h>

#include <tgf.h>
#include <track.h>
Expand All @@ -33,6 +34,14 @@

#include <tgfclient.h>

extern "C" {
#include "lua.h"
#include "lualib.h"
#include "lauxlib.h"
}
#include "luaT.h"
#include "TH.h"

static tTrack *curTrack;

static void initTrack(int index, tTrack* track, void *carHandle, void **carParmHandle, tSituation *s);
Expand All @@ -42,13 +51,23 @@ static void endrace(int index, tCarElt *car, tSituation *s);
static void shutdown(int index);
static int InitFuncPt(int index, void *pt);

static lua_State *L = NULL;

/*
* Module entry point
*/
extern "C" int
jz(tModInfo *modInfo)
{
L = luaL_newstate();
luaL_openlibs(L);
if (luaL_loadfile(L, "/home/jure/build/torcs-1.3.6/src/drivers/jz/main.lua")) {
err(0, "luaL_loadfile");
}
if (lua_pcall(L, 0, 0, 0)) {
err(0, "lua_pcall");
}

memset(modInfo, 0, 10*sizeof(tModInfo));

modInfo->name = strdup("jz"); /* name of the module (short) */
Expand Down Expand Up @@ -94,6 +113,9 @@ newrace(int index, tCarElt* car, tSituation *s)
extern tRmInfo *ReInfo;
void reMovieCapture(void *);

#define WIDTH 160
#define HEIGHT 120
unsigned char img[3 * WIDTH * HEIGHT];
unsigned long long tick1;

/* Drive during race. */
Expand All @@ -102,18 +124,30 @@ drive(int index, tCarElt* car, tSituation *s)
{
float angle;

if (tick1 == 150) {
reMovieCapture(NULL);
}
glReadPixels(0, 0, WIDTH, WIDTH, GL_RGB, GL_UNSIGNED_BYTE, (GLvoid*)img);

THByteStorage *storage = THByteStorage_newWithData(img, 3 * WIDTH * HEIGHT);
THByteTensor *tensor = THByteTensor_newWithStorage1d(storage, 0, 3 * WIDTH * HEIGHT, 1);
lua_getglobal(L, "drive");
luaT_pushudata(L, (void *)tensor, "torch.ByteTensor");
lua_pcall(L, 1, 1, 0);
angle = lua_tonumber(L, -1);

memset((void *)&car->ctrl, 0, sizeof(tCarCtrl));
car->ctrl.steer = angle;
car->ctrl.gear = 1;
car->ctrl.accelCmd = 0.3;
car->ctrl.brakeCmd = 0.0;

/*
angle = RtTrackSideTgAngleL(&(car->_trkPos)) - car->_yaw;
NORM_PI_PI(angle);
angle -= (car->_trkPos.toMiddle / car->_trkPos.seg->width);
/* predict with nn */
// predict with nn
car->targets[0] = angle / car->_steerLock;
/* actual driving command */
// actual driving command
float d = sin((double)tick1 / 500) / 2.;
angle = (angle + d) / car->_steerLock;
Expand All @@ -122,6 +156,7 @@ drive(int index, tCarElt* car, tSituation *s)
car->ctrl.gear = 1;
car->ctrl.accelCmd = 0.3;
car->ctrl.brakeCmd = 0.0;
*/

tick1++;
}
Expand Down
16 changes: 16 additions & 0 deletions src/drivers/jz/main.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
require('cunn')
require('cudnn')
require('image')

local base_dir = '/home/jure/devel/torcs'
local net = torch.load(base_dir .. '/net/net.t7')
local x_batch = torch.CudaTensor(1, 3, 120, 160)
local img_mean = torch.load(base_dir .. '/data/img_mean.t7'):view(1, 3, 1, 1):expandAs(x_batch):cuda()

function drive(img)
img = image.vflip(img:view(120, 160, 3):permute(3, 1, 2))
x_batch:copy(img):add(-1, img_mean)
net:forward(x_batch)
angle = net.output[1]
return angle
end
2 changes: 1 addition & 1 deletion src/interfaces/raceman.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ typedef int (*tfRmRunState) (struct RmInfo *);


#define RCM_MAX_DT_SIMU 0.002
#define RCM_MAX_DT_ROBOTS 0.02
#define RCM_MAX_DT_ROBOTS 0.1

/** General info on current race */
typedef struct {
Expand Down
3 changes: 2 additions & 1 deletion src/linux/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ SOLIBS = -lracescreens \
-lplibul \
-lraceengine \
-lmusicplayer \
-llearning
-llearning \
-L/home/jure/torch/install/lib -lluajit -lluaT -lTH

EXPDIR = include

Expand Down

1 comment on commit f84e113

@etienne87
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello! Could you explain a little how you do the training? Do you do it online while the game is running? Thanks a lot in advance

Please sign in to comment.