Coverage for tcprocd/clienthandler.py: 98.70%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

239 statements  

1"""tcprocd client handler.""" 

from __future__ import unicode_literals, print_function, absolute_import

import errno
import grp
import logging
import pwd
import socket
import struct

try:
    import socketserver
except ImportError:  # Python 2 module name
    import SocketServer as socketserver

from tcprocd.protocol import Protocol, ProtocolError, Disconnect
from tcprocd.runner import Runner
from tcprocd.user import User

15 

16 

# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)

18 

19 

class ClientHandler(socketserver.StreamRequestHandler):
    """
    Started by ``TCPServer`` for every client.

    One instance serves one connection. It sends the server version,
    authenticates the peer and then dispatches each received command to the
    matching ``do_<command>`` method until the connection ends.

    :param request: the client socket; it is wrapped in :class:`Protocol`.
    :param client_address: the client address, passed through to
        :class:`socketserver.StreamRequestHandler`.
    :param server: the owning server. It provides the default ``user`` and
        ``group``, ``version``, ``handlers``, ``runners``, ``users``,
        ``config_file``, ``is_unix_domain`` and the ``get_user``,
        ``write_config`` and ``on_runner_exit`` helpers used below.
    """

    #: Seconds a client may take to send each authentication part.
    AUTH_TIMEOUT = 30

    def __init__(self, request, client_address, server):
        """Initialize client handler.

        Every attribute is set *before* delegating to the base class, because
        ``BaseRequestHandler.__init__`` immediately runs ``setup()``,
        ``handle()`` and ``finish()``.
        """
        #: the authenticated user object; ``None`` until authentication succeeds
        self.user = None

        #: system user and group that started processes run as
        self.as_user = server.user
        self.as_group = server.group

        self.protocol = Protocol(request)

        socketserver.StreamRequestHandler.__init__(self, request, client_address, server)

    def send_version(self):
        """Send the version to the client."""
        self.protocol.send_part(3, self.server.version)

    def _recv_auth_part(self, size):
        """Receive one authentication part within ``AUTH_TIMEOUT`` seconds.

        ``size`` is passed unchanged to ``Protocol.recv_part``. The default
        protocol timeout is restored even when receiving fails.
        """
        self.protocol.set_timeout(self.AUTH_TIMEOUT)
        try:
            return self.protocol.recv_part(size)
        finally:
            self.protocol.set_timeout()

    def _adopt_system_identity(self, default_user, default_group):
        """Choose the system user and group for processes of ``self.user``.

        The user's configured ``user``/``group`` take precedence. Otherwise
        the given defaults are used.
        """
        self.as_user = self.user.user or default_user
        logger.debug('authentication ({}): system user "{}"'.format(self.user.username, self.as_user))

        self.as_group = self.user.group or default_group
        logger.debug('authentication ({}): system group "{}"'.format(self.user.username, self.as_group))

    def _authenticate_unix_domain(self):
        """Authenticate a peer connected through the unix domain socket.

        The peer's uid and gid are read with ``SO_PEERCRED``. A local user
        who is not configured yet is added without a password. A configured
        user who has a password must send it. On success ``self.user``,
        ``self.as_user`` and ``self.as_group`` are set. On failure
        ``self.user`` stays ``None``.
        """
        # socket.SO_PEERCRED is missing on some Python versions; fall back to
        # 17, presumably the Linux value.
        so_peercred = getattr(socket, 'SO_PEERCRED', 17)

        creds = self.request.getsockopt(socket.SOL_SOCKET, so_peercred, struct.calcsize('3i'))
        _pid, uid, gid = struct.unpack('3i', creds)

        # A uid/gid without a passwd/group entry cannot be mapped to a name.
        # Treat that as a failed authentication rather than letting KeyError
        # escape from ``handle``.
        try:
            username = pwd.getpwuid(uid).pw_name
            groupname = grp.getgrgid(gid).gr_name
        except KeyError:
            logger.warning('authentication: no passwd/group entry for uid {} / gid {}'.format(uid, gid))
            return

        logger.debug('authentication: username: "{}"'.format(username))

        user = self.server.get_user(username)

        # The unix socket only accepts local users, so a password is required
        # only if one has been configured for this user.
        if user is None:
            logger.info('authentication: adding local user "{}" to {}'.format(username, self.server.config_file))
            user = User(username=username, password=None, admin=False)
            self.server.users.append(user)
            self.server.write_config()

        elif user.password is not None:
            logger.info('authentication ({})'.format(user.username))
            self.protocol.send_part(2, Protocol.AUTHENTICATION_REQUIRED)

            logger.debug('authentication ({}): receiving password...'.format(user.username))
            password = self._recv_auth_part(2)
            logger.debug('authentication ({}): received password'.format(user.username))

            if not user.check_password(password):
                logger.warning('authentication ({}): password mismatch!'.format(user.username))
                return

        self.user = user
        # Unless configured otherwise, run processes as the connecting user.
        self._adopt_system_identity(username, groupname)
        logger.info('authentication ({}): success'.format(self.user.username))

    def _authenticate_tcp(self):
        """Authenticate a TCP peer with a username and password.

        Each credential must arrive within ``AUTH_TIMEOUT`` seconds. On
        success ``self.user``, ``self.as_user`` and ``self.as_group`` are set.
        On failure ``self.user`` stays ``None``.
        """
        logger.info('authentication')
        self.protocol.send_part(2, Protocol.AUTHENTICATION_REQUIRED)

        logger.debug('authentication: receiving credentials...')
        username = self._recv_auth_part(2)
        password = self._recv_auth_part(2)
        logger.debug('authentication: received credentials')

        user = self.server.get_user(username)

        if user is None:
            logger.warning('authentication: Unknown user: "{}"'.format(username))
            return

        if not user.check_password(password):
            logger.warning('authentication ({}): password mismatch!'.format(username))
            return

        self.user = user
        # Unless configured otherwise, run processes as the server's identity.
        self._adopt_system_identity(self.server.user, self.server.group)
        logger.info('authentication ({}): success'.format(self.user.username))

    def authenticate(self):
        """Authenticate the client, set ``self.user`` and report the result.

        Sends ``AUTHENTICATION_ERROR`` or ``OK`` to the client.

        :return: :class:`bool` - Whether the authentication succeeded.
        """
        if self.server.is_unix_domain:
            logger.debug('authentication: unix domain')
            self._authenticate_unix_domain()
        else:
            logger.debug('authentication: tcp')
            self._authenticate_tcp()

        if self.user is None:
            self.protocol.send_part(2, Protocol.AUTHENTICATION_ERROR)
            return False

        self.protocol.send_part(2, Protocol.OK)
        return True

    def handle_command(self):
        """Wait for a command and handle it."""
        command = self.protocol.recv_part(2)
        logger.debug(self.logf('command: "{}"'.format(command)))
        self.exec_command(command)

    def _send_error(self):
        """Report a failed command to the client, if the connection still works."""
        try:
            self.protocol.send_part(2, Protocol.ERROR)
        except (socket.error, Disconnect) as e:
            logger.debug(self.logf('could not send error: {}'.format(e)))

    def handle(self):
        """Serve the connection until it closes.

        Called by the base class once the handler is constructed. Registers
        the handler in ``server.handlers``, sends the version information,
        authenticates the client and then calls :meth:`handle_command` in a
        loop. On exit the connection is closed and the handler unregistered.
        """
        self.server.handlers.append(self)
        try:
            self.send_version()
            logger.info('connected.')

            try:
                authenticated = self.authenticate()
            except Disconnect:
                logger.info('disconnected during authentication.')
                return
            except socket.error as e:
                # Presumably also covers socket.timeout raised once
                # AUTH_TIMEOUT expires (it subclasses socket.error).
                logger.info('authentication aborted: {}'.format(e))
                return

            if not authenticated:
                logger.info('authentication failed!')
                return

            logger.info(self.logf('authenticated.'))

            while True:
                try:
                    self.handle_command()
                except socket.error as e:
                    # A reset by the peer is an ordinary way to disconnect.
                    if getattr(e, 'errno', None) == errno.ECONNRESET:
                        logger.info(self.logf('connection reset by peer.'))
                    else:
                        logger.exception(e)
                    break
                except ProtocolError as e:
                    # Log first, so the cause is kept even if the reply fails.
                    logger.exception(e)
                    self._send_error()
                    break
                except Disconnect:
                    logger.info(self.logf('disconnected.'))
                    break
                except Exception as e:
                    # Catch all other exceptions, too. But log them.
                    logger.exception(e)
                    self._send_error()
                    break

        finally:
            try:
                self.connection.close()
            finally:
                # Unregister even if closing the socket fails.
                self.server.handlers.remove(self)

    def exec_command(self, command):
        """Dispatch ``command`` to the matching ``do_<command>`` method.

        :raises ProtocolError: if there is no such method.
        """
        method = 'do_{}'.format(command)
        handler = getattr(self, method, None)
        if handler is None:
            # TODO: ProtocolError('Unknown Command', extra={"code": 1, "command": command})
            logger.warning(self.logf('ProtocolError: Unknown command: "{}"'.format(command)))
            raise ProtocolError('Unknown command: "{}"'.format(command))
        logger.debug('running ClientHandler.{}'.format(method))
        return handler()

    def do_list(self):
        r"""Send the names of the runners visible to this user, delimited by '\n'.

        Admins see every runner; other users see only their own. Sends
        ``OFFLINE`` if no runner exists at all.
        """
        if not self.server.runners:
            logger.debug(self.logf(Protocol.OFFLINE))
            self.protocol.send_part(2, Protocol.OFFLINE)
            return

        if self.user.admin:
            runner_names = sorted(self.server.runners)
        else:
            runner_names = sorted(
                name for name, runner in self.server.runners.items()
                if runner.owner == self.user
            )
        part = '\n'.join(runner_names)
        logger.debug(self.logf(part))
        self.protocol.send_part(2, Protocol.OK)
        self.protocol.send_part(6, part)

    def do_start(self):
        """Start a new runner.

        Reads the runner name, the command and an optional working path (an
        empty path means ``None``). Sends ``EXISTS`` if the name is taken,
        otherwise ``OK``. The runner runs as ``as_user``/``as_group``.
        """
        name = self.protocol.recv_part(2)
        command = self.protocol.recv_part(3)
        path = self.protocol.recv_part(3)

        if not path:
            path = None

        if name in self.server.runners:
            logger.debug(self.logf(Protocol.EXISTS))
            self.protocol.send_part(2, Protocol.EXISTS)
            return

        r = Runner(name,
                   command,
                   owner=self.user,
                   on_exit=self.server.on_runner_exit,
                   path=path,
                   as_user=self.as_user,
                   as_group=self.as_group)

        # NOTE(review): the runner is registered only after start(). Confirm
        # that on_runner_exit cannot fire before this assignment; otherwise
        # the entry may linger.
        r.start()
        self.server.runners[name] = r
        logger.info(self.logf('started'))
        self.protocol.send_part(2, Protocol.OK)

    def do_kill(self):
        """Kill the named runner, or all of this user's runners.

        An empty name kills every runner owned by the user. Sends ``OFFLINE``
        if nothing runs or the named runner is unknown, and
        ``PERMISSION_DENIED`` for another user's runner. Otherwise sends
        ``OK``.
        """
        name = self.protocol.recv_part(2)
        if name == '':
            if not self.server.runners:
                logger.debug(self.logf(Protocol.OFFLINE))
                self.protocol.send_part(2, Protocol.OFFLINE)
                return

            runners = [
                runner for runner in self.server.runners.values()
                if runner.owner == self.user
            ]
        else:
            # get_runner already replies OFFLINE / PERMISSION_DENIED.
            runner = self.get_runner(name)
            if runner is None:
                return
            runners = [runner]

        for runner in runners:
            runner.kill()
            logger.info(self.logf('Killed ' + runner.name))

        self.protocol.send_part(2, Protocol.OK)

    def do_cat(self):
        """Send the runner's output history, starting at line ``start``.

        :raises ProtocolError: if ``start`` is not an integer.
        """
        name = self.protocol.recv_part(2)
        start = self.protocol.recv_part(1)

        try:
            start = int(start)
        except ValueError:
            raise ProtocolError('start is not a number')

        runner = self.get_runner(name)
        if runner is None:
            return

        part = '\n'.join(runner.history[start:])

        if not part.strip():
            logger.debug(self.logf('no output'))
        else:
            msg = 'Output: {}{}'.format(part[:20], '...' if len(part) > 20 else '')
            logger.debug(self.logf(msg))
        self.protocol.send_part(2, Protocol.OK)
        self.protocol.send_part(6, part)

    def do_command(self):
        """Write the given line to a runner's stdin."""
        name = self.protocol.recv_part(2)
        command = self.protocol.recv_part(3)

        runner = self.get_runner(name)
        if runner is None:
            logger.info(self.logf('Unknown runner "{}"'.format(name)))
            return

        runner.write_line(command)
        logger.debug(self.logf('Sent command "{}" to runner "{}"'.format(command, runner.name)))
        self.protocol.send_part(2, Protocol.OK)

    def get_runner(self, name):
        """Return the runner called ``name`` if this user owns it.

        Otherwise sends ``OFFLINE`` (unknown name) or ``PERMISSION_DENIED``
        (another user's runner) and returns ``None``.
        """
        try:
            runner = self.server.runners[name]
        except KeyError:
            logger.debug(self.logf(Protocol.OFFLINE))
            self.protocol.send_part(2, Protocol.OFFLINE)
            return None

        if runner.owner != self.user:
            logger.debug(self.logf(Protocol.PERMISSION_DENIED))
            self.protocol.send_part(2, Protocol.PERMISSION_DENIED)
            return None

        return runner

    def do_attach(self):
        """Attach to the shell.

        Any further line received is written to the process, and this client
        is added to the runner's ``attached_handlers``. The loop ends only
        through an exception, such as a disconnect, which ``handle`` deals
        with.
        """
        name = self.protocol.recv_part(2)

        runner = self.get_runner(name)
        if runner is None:
            return

        self.protocol.send_part(2, Protocol.OK)

        runner.attached_handlers.append(self)
        self.protocol.set_timeout(None)

        try:
            while True:
                command = self.protocol.readline()
                logger.debug(self.logf('Writing: "{}"'.format(command)))
                runner.write_line(command)
        finally:
            runner.attached_handlers.remove(self)
            self.protocol.set_timeout()

    def logf(self, *args):
        """Format args for logging, prefixed with the current user."""
        return '({}) {}'.format(self.user, ' '.join([str(a) for a in args]))