-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
187 lines (161 loc) · 4.95 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
package main
import (
"errors"
"flag"
"fmt"
"gotftp/goftp"
"log"
"net"
"path/filepath"
"sync"
"time"
)
// TODO: Deal with the standard binary option as a first step - netascii next
// Package-level settings; readFlags fills them in from the command line.
var (
	// tftpDirectory is the directory files are read from and written to (-d).
	tftpDirectory string

	// port is the UDP port the server listens on (-p).
	port int

	// enableWrites reports whether write requests are accepted (-w).
	enableWrites bool
)
// SessionKey builds the map key under which a client's session is stored:
// the client's IP and port joined by a colon, e.g. "192.0.2.7:4242".
// The address must be non-nil.
func SessionKey(a *net.UDPAddr) string {
	host, portNum := a.IP, a.Port
	return fmt.Sprintf("%s:%d", host, portNum)
}
// main runs the TFTP server. It parses the command-line flags, binds a UDP4
// socket on the configured port, starts a background sweeper that drops
// finished sessions, and then serves datagrams one at a time on the calling
// goroutine until reading from the socket fails.
//
// Sessions are kept in two maps (reads and writes) keyed by SessionKey of the
// client address. Each map is guarded by its own mutex; the sweeper and the
// serving loop only touch a map, or a session stored in it, while holding the
// matching mutex.
func main() {
	theLogger := log.Default()
	readFlags()
	theLogger.Printf("Starting TFTP server on %v:%v\n\n", "localhost", port)

	s, err := net.ResolveUDPAddr("udp4", fmt.Sprintf(":%v", port))
	if err != nil {
		theLogger.Fatal(err)
	}
	connection, err := net.ListenUDP("udp4", s)
	if err != nil {
		theLogger.Fatal(err)
	}
	defer connection.Close()

	RRQSessions := map[string]*goftp.RRQSession{}
	var RRQSessionsMu sync.Mutex
	WRQSessions := map[string]*goftp.WRQSession{}
	var WRQSessionsMu sync.Mutex
	workingBuffer := make([]byte, 1024)

	// logSessionNumbers reports how many sessions of each kind are tracked.
	// It takes both mutexes in turn, so callers must not hold either one.
	logSessionNumbers := func() {
		RRQSessionsMu.Lock()
		reads := len(RRQSessions)
		RRQSessionsMu.Unlock()

		WRQSessionsMu.Lock()
		writes := len(WRQSessions)
		WRQSessionsMu.Unlock()

		theLogger.Printf("Currently have %v number of active read sessions and %v number of active write sessions \n", reads, writes)
	}

	// unknownSession builds the reply for a DATA/ACK packet that does not
	// belong to any session we are tracking.
	unknownSession := func() []byte {
		return goftp.GenerateErrorMessage(goftp.GenerateTFTPError(goftp.NOT_DEFINED, "No active transfer for this address"))
	}

	// send writes a reply to the client, if there is one. A failed write is
	// logged rather than treated as fatal: it only affects that one client,
	// and the server should keep serving everybody else.
	send := func(data []byte, addr *net.UDPAddr) {
		if len(data) == 0 {
			return
		}
		if _, errS := connection.WriteToUDP(data, addr); errS != nil {
			theLogger.Println(errS)
		}
	}

	// gc removes sessions that completed more than `timeout` ago. It runs for
	// the lifetime of the process, sweeping once a second.
	gc := func() {
		// Grace period between a session being closed and it being dropped.
		const timeout = 10 * time.Second
		for {
			// Log lines are built under the lock but emitted afterwards,
			// because logSessionNumbers needs to take both mutexes itself.
			var finished []string

			// The whole iteration happens under the mutex: ranging over a
			// map while another goroutine writes to it is a data race.
			WRQSessionsMu.Lock()
			for k, e := range WRQSessions {
				if e.Completed && time.Since(e.ClosedAt) > timeout {
					delete(WRQSessions, k)
					finished = append(finished, fmt.Sprintf("Write request from IP:%v for file %v complete", k, e.Filename))
				}
			}
			WRQSessionsMu.Unlock()

			RRQSessionsMu.Lock()
			for k, e := range RRQSessions {
				if e.Completed && time.Since(e.ClosedAt) > timeout {
					delete(RRQSessions, k)
					finished = append(finished, fmt.Sprintf("Read request from IP:%v for file %v complete", k, e.Filename))
				}
			}
			RRQSessionsMu.Unlock()

			for _, msg := range finished {
				theLogger.Print(msg)
				logSessionNumbers()
			}
			time.Sleep(time.Second)
		}
	}
	go gc()

	for {
		n, addr, err := connection.ReadFromUDP(workingBuffer)
		if err != nil {
			theLogger.Println(err)
			return
		}
		d := goftp.DatagramBuffer{
			Buffer: workingBuffer[:n], // only the bytes that were actually read
			Offset: 0,
		}

		dgo, err := goftp.DestructureDatagram(d)
		if err != nil {
			// Unparseable datagram: tell the client and move on.
			send(goftp.GenerateErrorMessage(err), addr)
			continue
		}

		// Whatever we are going to send back across the wire.
		var data []byte

		// NOTE(review): the second result of GenerateRRQMessage,
		// GenerateWRQMessage and AcknowledgeWRQSession is still discarded
		// below; its type is not visible from this file - confirm whether it
		// is an error that should be reported to the client.
		switch dgo.Opcode {
		case goftp.OPCODE_RRQ:
			theLogger.Printf("Got Read request from IP:%v:%v for file %v", addr.IP, addr.Port, dgo.Filename)
			session, errR := goftp.SetupRRQSession(cleanTftpDirectory(), dgo, addr)
			if errR != nil {
				data = goftp.GenerateErrorMessage(errR)
				break
			}
			// The session is not shared yet, so no lock is needed here.
			data, _ = goftp.GenerateRRQMessage(session)
			RRQSessionsMu.Lock()
			RRQSessions[SessionKey(addr)] = session
			RRQSessionsMu.Unlock()

		case goftp.OPCODE_WRQ:
			theLogger.Printf("Got Write request from IP:%v:%v for file %v", addr.IP, addr.Port, dgo.Filename)
			if !enableWrites {
				data = goftp.GenerateErrorMessage(errors.New("Not accepting writes at the moment"))
				break
			}
			session, errW := goftp.SetupWRQSession(cleanTftpDirectory(), dgo, addr)
			if errW != nil {
				data = goftp.GenerateErrorMessage(errW)
				break
			}
			data, _ = goftp.GenerateWRQMessage(session)
			WRQSessionsMu.Lock()
			WRQSessions[SessionKey(addr)] = session
			WRQSessionsMu.Unlock()

		case goftp.OPCODE_DATA: // incoming data for a write session
			// Held across the goftp call because the sweeper reads fields of
			// the stored session concurrently.
			WRQSessionsMu.Lock()
			if session := WRQSessions[SessionKey(addr)]; session != nil {
				data, _ = goftp.AcknowledgeWRQSession(session, dgo)
			} else {
				data = unknownSession()
			}
			WRQSessionsMu.Unlock()

		case goftp.OPCODE_ACK: // client acknowledged a block of a read session
			RRQSessionsMu.Lock()
			session := RRQSessions[SessionKey(addr)]
			switch {
			case session == nil:
				// A stray ACK must not bring the server down.
				data = unknownSession()
			default:
				if errA := goftp.AcknowledgeRRQSession(session, dgo); errA != nil {
					data = goftp.GenerateErrorMessage(errA)
				} else if !session.Completed {
					data, _ = goftp.GenerateRRQMessage(session)
				}
			}
			RRQSessionsMu.Unlock()

		default:
			data = goftp.GenerateErrorMessage(goftp.GenerateTFTPError(goftp.NOT_DEFINED, "Only able to send you files right now"))
		}

		send(data, addr)
	}
}
func readFlags() {
flag.StringVar(&tftpDirectory, "d", "./files/", "Directory to read/write files to")
flag.IntVar(&port, "p", 6999, "Port to run TFTP server on")
flag.BoolVar(&enableWrites, "w", false, "Allow users to write to the server (potentially unsafe)")
flag.Parse()
}
// cleanTftpDirectory returns the absolute form of tftpDirectory. If the path
// cannot be made absolute, the error is logged and the process exits.
func cleanTftpDirectory() string {
	abs, err := filepath.Abs(tftpDirectory)
	if err == nil {
		return abs
	}
	log.Fatal(err)
	return abs
}