Coverage for src/hods/utils.py: 100.00%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

194 statements  

1"""hods - home directory synchronization. 

2 

3Copyright (C) 2016-2020 Mathias Stelzer <knoppo@rolln.de> 

4 

5hods is free software: you can redistribute it and/or modify 

6it under the terms of the GNU General Public License as published by 

7the Free Software Foundation, either version 3 of the License, or 

8(at your option) any later version. 

9 

10hods is distributed in the hope that it will be useful, 

11but WITHOUT ANY WARRANTY; without even the implied warranty of 

12MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

13GNU General Public License for more details. 

14 

15You should have received a copy of the GNU General Public License 

16along with this program. If not, see <http://www.gnu.org/licenses/>. 

17""" 

18import logging 

19import os 

20import platform 

21import pwd 

22import re 

23import subprocess 

24from collections import namedtuple 

25 

26logger = logging.getLogger(__name__) 

27 

28 

def is_in_path(executable):
    """Return ``True`` if an executable file named *executable* exists in PATH.

    Every entry of the ``PATH`` environment variable is stripped of
    surrounding whitespace; empty entries are ignored.
    """
    raw_entries = os.getenv('PATH', '').split(os.pathsep)
    directories = (entry.strip() for entry in raw_entries)
    candidates = (
        os.path.join(directory, executable)
        for directory in directories
        if directory
    )
    # A hit must be a regular file that the current process may execute.
    return any(
        os.path.isfile(candidate) and os.access(candidate, os.X_OK)
        for candidate in candidates
    )

40 

41 

42try: 

43 from shutil import which 

44except ImportError: # pragma: no cover 

45 try: 

46 from distutils.spawn import find_executable as which 

47 except ImportError: 

48 which = is_in_path 

49 

50 

def get_hostname():
    """Return the short host name (everything before the first dot).

    The ``HOSTNAME`` environment variable wins, then ``COMPUTERNAME``,
    then whatever ``platform.node()`` reports.
    """
    fallback = os.getenv('COMPUTERNAME', platform.node())
    full_name = os.getenv('HOSTNAME', fallback)
    short_name, _, _ = full_name.partition('.')
    return short_name

55 

56 

57# Don't trust the environment with the current user! 

58# Get the user and home directory from passwd using the effective user id to 

59# allow execution in a different environment. This makes it easier to use 

60# hods with ssh-disabled remote root users. 

def pw():
    """Look up the passwd entry that belongs to the effective user id."""
    effective_uid = os.geteuid()
    return pwd.getpwuid(effective_uid)

64 

65 

def pw_user():
    """Return the login name stored in the passwd entry from ``pw()``."""
    entry = pw()
    return entry.pw_name

69 

70 

def pw_home():
    """Return the home directory stored in the passwd entry from ``pw()``."""
    entry = pw()
    return entry.pw_dir

74 

75 

def run(*cmd, check=True, capture_output=True, hide=False, **kwargs):
    """Wrapper around ``subprocess.Popen`` with extensive logging.

    Args:
        cmd:
            Command arguments to run.
        check:
            Raise `subprocess.CalledProcessError` if the command exits with a
            non-zero return code.
        capture_output:
            Pipe and store stdout. stderr is redirected into stdout, so the
            returned object never carries a separate stderr.
        hide:
            Do not forward subprocess output to this process' stdout. Does not
            affect logging. Ignored if ``capture_output`` is `False`.
        **kwargs:
            Passed on to `subprocess.Popen`. ``universal_newlines`` defaults
            to `True`; ``stdout``/``stderr`` are overwritten when
            ``capture_output`` is `True`.

    Return:
        `subprocess.CompletedProcess` whose ``stdout`` is the captured text,
        or `None` when ``capture_output`` is `False`.

    Raises:
        `subprocess.CalledProcessError` if ``check`` is `True` and the command
        exits with a non-zero return code.
    """
    kwargs.setdefault('universal_newlines', True)

    if capture_output:
        kwargs['stdout'] = subprocess.PIPE
        kwargs['stderr'] = subprocess.STDOUT

    shell_cmd = subprocess.list2cmdline(cmd)
    logger.info(shell_cmd)

    stdout = None
    with subprocess.Popen(cmd, **kwargs) as process:
        if capture_output:
            # Stream line by line so output is logged/echoed while the
            # command is still running; collect the raw lines for the result.
            captured = []
            for line in process.stdout:
                captured.append(line)
                line = line.rstrip()
                logger.info('%s: %s', cmd[0], line)
                if not hide:
                    print(line)
            stdout = ''.join(captured)
            process.stdout.close()
            retcode = process.wait()
        else:
            process.communicate()
            retcode = process.poll()

    if check and retcode:
        # Not inside an exception handler: logger.exception() would append a
        # meaningless "NoneType: None" traceback, so log a plain error.
        logger.error('subprocess error:%s\n%s', shell_cmd, stdout)
        raise subprocess.CalledProcessError(retcode, process.args, stdout)

    logger.debug('subprocess finished successfully')
    return subprocess.CompletedProcess(process.args, retcode, stdout)

129 

130 

class ProcessError(Exception):
    """Common ancestor of every command-related error in this module."""

133 

134 

class SSHError(ProcessError):
    """Common ancestor of every ssh-related error."""

139 

140 

class RSyncError(ProcessError):
    """Common ancestor of every rsync-related error."""

145 

146 

class GitError(ProcessError):
    """Common ancestor of every git-related error."""

151 

152 

def run_ssh(server, *cmd, **kwargs):
    """Run an ssh subprocess with the given arguments.

    Executes ``ssh <server> <cmd...>`` through `run`.

    :param server: Passed to ssh as its first argument.
    :param cmd: Further arguments appended to the ssh command line.
    :param kwargs: Passed on to `run`.
    :return: Whatever `run` returns.
    :raises SSHError: if `run` raises `FileNotFoundError`, which is reported
        as ssh not being installed.
    """
    try:
        return run('ssh', server, *cmd, **kwargs)
    except FileNotFoundError as exc:
        # Chain explicitly so the original OS error stays visible as the cause.
        raise SSHError('ssh is not installed') from exc

159 

160 

def run_rsync(src, dst, **kwargs):
    """Run an rsync subprocess to synchronize the given paths.

    Executes ``rsync -ave ssh <src> <dst>`` through `run`.

    :param src: Source argument handed to rsync.
    :param dst: Destination argument handed to rsync.
    :param kwargs: Passed on to `run`.
    :return: Whatever `run` returns.
    :raises RSyncError: if `run` raises `FileNotFoundError`, which is reported
        as rsync not being installed.
    """
    try:
        return run('rsync', '-ave', 'ssh', src, dst, **kwargs)
    except FileNotFoundError as exc:
        # Chain explicitly so the original OS error stays visible as the cause.
        raise RSyncError('rsync is not installed') from exc

173 

174 

def run_git(*cmd, **kwargs):
    """Run a git subprocess with the given arguments.

    Executes ``git <cmd...>`` through `run`.

    :param cmd: Arguments appended to the git command line.
    :param kwargs: Passed on to `run`.
    :return: Whatever `run` returns.
    :raises GitError: if `run` raises `FileNotFoundError`, which is reported
        as git not being installed.
    """
    try:
        return run('git', *cmd, **kwargs)
    except FileNotFoundError as exc:
        # Chain explicitly so the original OS error stays visible as the cause.
        raise GitError('git is not installed') from exc

181 

182 

def format_kwargs(*args, **kwargs):
    """Render positional and keyword arguments as one comma separated string.

    Positional values are converted with ``str``; keyword values appear as
    ``name=value``.
    """
    positional = map(str, args)
    keyword = ('{}={}'.format(name, value) for name, value in kwargs.items())
    return ', '.join([*positional, *keyword])

188 

189 

def clean_server(server):
    """Extract and clean the ``[user@]host`` part of the given server address.

    Surrounding whitespace is stripped and everything from the first ``:``
    on (e.g. a remote path) is removed.

    :param server: Server address such as ``user@host:path``, or `None`.
    :return: The cleaned address, or `None` if *server* is `None` or blank.
    :raises ValueError: if the address contains ``@`` but the user or the
        host part is empty.
    """
    if server is None:
        return None
    server = server.strip()
    if not server:
        return None
    if '@' in server:
        user, host = server.split('@', 1)
        if not user:
            raise ValueError('Server address contains "@" but user is empty')
        # Validate the host without its ":path" suffix, otherwise an address
        # like "user@:path" would slip through and be returned as "user@".
        if not host.split(':', 1)[0]:
            raise ValueError('Server address contains "@" but server is empty')
    return server.split(':', 1)[0]

206 

207 

class Sortable:
    """Mixin giving an item ``move_up``/``move_down`` within its sibling list."""

    def _get_sortable_items(self):
        """Return the list this item is sorted in (``self.parent.children``)."""
        return self.parent.children

    def _move(self, up=True):
        """Shift the item by one slot; return whether it actually moved."""
        siblings = self._get_sortable_items()
        try:
            current = siblings.index(self)
        except ValueError:
            # The item is not part of the list at all.
            return False

        target = current - 1 if up else current + 1
        if target < 0 or target >= len(siblings):
            # Already at the respective end of the list.
            return False

        siblings.remove(self)
        siblings.insert(target, self)
        return True

    def move_up(self):
        """Shift the item one slot towards the start of the list if possible."""
        return self._move(up=True)

    def move_down(self):
        """Shift the item one slot towards the end of the list if possible."""
        return self._move(up=False)

242 

243 

class SSHAgentConnectionError(Exception):
    """Raised when the ssh agent cannot be reached."""

    def __init__(self, *args, **kwargs):
        """Prepend the fixed connection-failure message to the given arguments."""
        super().__init__(
            'Failed to connect to ssh agent. Is it running?',
            *args,
            **kwargs
        )

251 

252 

# One key as listed by the ssh agent; every field is stored as parsed text.
SSHAgentKey = namedtuple('SSHAgentKey', 'length algorithm key path type')

254 

255 

class SSHAgent:
    """The ssh-agent in the current environment.

    The agent is addressed through the ``SSH_AUTH_SOCK`` (and, for killing,
    ``SSH_AGENT_PID``) environment variables of this process, and is driven
    by calling the ``ssh-agent`` and ``ssh-add`` command line tools.
    """

    # One line of ``ssh-add -l`` output:
    #   <bits> <ALGORITHM>:<fingerprint> <comment> (<type>)
    # The fingerprint is base64 and may therefore contain "/" as well as "+".
    # The comment is usually the key file path but may be any text; the type
    # may contain "-" (e.g. "ED25519-SK").
    _KEY_PATTERN = re.compile(
        r'^(?P<length>\d+) '
        r'(?P<algorithm>[A-Z0-9]+):(?P<key>[a-zA-Z0-9+/]+) '
        r'(?P<path>.*) '
        r'\((?P<type>[a-zA-Z0-9-]+)\)$'  # noqa: C812
    )

    def __init__(self):
        """Initialize ssh agent."""
        # Set to True once start() has launched an agent itself.
        self.started = False

    @property
    def auth_sock(self):
        """The SSH_AUTH_SOCK environment variable, or `None` if unset."""
        return os.environ.get('SSH_AUTH_SOCK', None)

    @auth_sock.setter
    def auth_sock(self, value):
        """Set the SSH_AUTH_SOCK environment variable; `None` removes it."""
        if value is None:
            os.environ.pop('SSH_AUTH_SOCK', None)
            return
        os.environ['SSH_AUTH_SOCK'] = value

    def start(self):
        """Start the agent and set environment variables.

        :return: `False` if an agent is already running, `True` if a new one
            was started.
        :raises ValueError: if the ``ssh-agent`` output contains no
            ``SSH_AUTH_SOCK`` assignment.
        """
        if self.is_running():
            return False

        output = subprocess.check_output(['ssh-agent'], stderr=subprocess.PIPE).decode()

        m = re.search('SSH_AUTH_SOCK=(?P<auth_sock>[^;]+);', output, re.DOTALL)
        if m is None:
            raise ValueError('SSH_AUTH_SOCK not found in ssh-agent output: ' + output)

        # ``ssh-agent -k`` locates the agent through SSH_AGENT_PID, so export
        # it as well; otherwise kill() cannot stop the agent started here.
        pid = re.search(r'SSH_AGENT_PID=(?P<pid>\d+);', output)
        if pid is not None:
            os.environ['SSH_AGENT_PID'] = pid.group('pid')

        self.auth_sock = m.group('auth_sock')
        self.started = True
        return True

    def is_running(self):
        """Check whether the agent is running.

        :return: `False` if ``SSH_AUTH_SOCK`` is unset/empty or the agent
            cannot be contacted, `True` otherwise.
        """
        if not self.auth_sock:
            return False
        try:
            # Exhaust the generator purely to trigger the connection check.
            list(self.gen())
        except SSHAgentConnectionError:
            return False
        return True

    def kill(self):
        """Kill the agent and remove environment variables."""
        subprocess.call(['ssh-agent', '-k'], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
        self.auth_sock = None
        os.environ.pop('SSH_AGENT_PID', None)

    def gen(self):
        """Generate active ssh agent keys.

        Runs ``ssh-add -l`` and yields one `SSHAgentKey` per output line.
        Yields nothing if the command exits with status 1.

        :raises SSHAgentConnectionError: if ``ssh-add`` exits with status 2.
        :raises subprocess.CalledProcessError: for any other non-zero status.
        :raises ValueError: if an output line cannot be parsed.
        """
        cmd = ['ssh-add', '-l']
        try:
            output = subprocess.check_output(cmd, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            if e.returncode == 1:
                return
            if e.returncode == 2:
                raise SSHAgentConnectionError() from e
            raise

        for line in output.decode().splitlines():
            line = line.strip()
            if not line:
                continue
            m = self._KEY_PATTERN.match(line)
            if m is None:
                raise ValueError('Invalid line in ssh-add output: {}'.format(line))
            yield SSHAgentKey(**m.groupdict())

    def has(self, key=None):
        """Check whether the ssh agent has the given or any key.

        :param key: Path to the private key (compared with the ``path`` field
            of the listed keys) or a listed key object. Checks for any key if
            none is given.
        :type key: `str` or `None`
        :return: Whether the key is active or not.
        :rtype: `bool`
        """
        keys = list(self.gen())
        if key:
            # Listed entries are namedtuples, so a plain path has to be
            # compared against their ``path`` field, not the tuple itself.
            return any(key == entry or key == entry.path for entry in keys)
        return bool(keys)

    def add(self, key=None):
        """Add a private key to the agent.

        :param key: Path to the private key. The default key is used if none is given.
        :type key: `str` or `None`
        :return: Whether the operation was successful.
        :rtype: `bool`
        :raises SSHAgentConnectionError: if ``ssh-add`` exits with status 2.
        """
        cmd = ['ssh-add']
        if key:
            cmd.append(key)
        # subprocess.call() never raises CalledProcessError, so the exit
        # status has to be inspected directly.
        exitcode = subprocess.call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
        if exitcode == 2:
            raise SSHAgentConnectionError()
        return not exitcode