Commit 7ea7a7c1 authored by Hanchuan Wu's avatar Hanchuan Wu
Browse files

[bin][util][common] Make util a package and write basic functions into

parent c90ef226
import os
import re
import sys
import fnmatch
import functools
import subprocess
import traceback
import datetime
import shlex
"warning": "\033[1;31m",
"highlight": "\033[1;34m",
"reset": "\033[0m",
"none": "",
def styledBotPrint(s, style="none", **kwargs):
sys.stdout.write("\n🤖 ")
print(s, **kwargs)
def addPrefix(prefix, text, separator=' '):
return prefix + separator + text
def addPrefixToLines(prefix, text, separator=' '):
return '\n'.join(
addPrefix(prefix, line, separator) for line in text.split('\n')
def escapeCharacter(text, character, escCharacter="\\"):
return text.replace(character, f'{escCharacter}{character}')
def escapeCharacters(text, characters, escCharacter="\\"):
for char in characters:
text = escapeCharacter(text, char, escCharacter)
return text
def indent(text, indentation=' '):
text = text.split('\n')
text = [indentation + line for line in text]
return '\n'.join(text)
def makeTable(dictList, config=None, padding=2):
if config is None:
config = {key: key for d in dictList for key in d}
def getColWidth(row):
return max(len(str(r)) for r in row) + padding*2
def getCol(key):
return [config[key]] + [d.get(key, "") for d in dictList]
widths = {key: getColWidth(getCol(key)) for key in config}
def makeRow(rowValues):
row = "|"
for key in config:
row += "{}|".format(rowValues.get(key, "").center(widths[key]))
return row
table = [makeRow({key: config[key] for key in config})]
table.append('|' + '|'.join('-'*widths[key] for key in config) + '|')
table.extend(makeRow(row) for row in dictList)
return '\n'.join(table)
def getCommandErrorHints(command):
......@@ -13,12 +82,13 @@ def getCommandErrorHints(command):
return None
# execute a command and retrieve the output
def runCommand(command, suppressTraceBack=False, errorMessage=''):
def runCommand(command, check=True, suppressTraceBack=False, errorMessage=''):
"""execute a command and retrieve the output"""
shell=True, check=True,
text=True, capture_output=True).stdout
shlex.split(command), check=check, text=True, capture_output=True
except Exception:
eType, eValue, eTraceback = sys.exc_info()
if suppressTraceBack:
......@@ -35,8 +105,9 @@ def runCommand(command, suppressTraceBack=False, errorMessage=''):
# decorator to call function from within the given path
def callFromPath(path):
"""decorator to call function from within the given path"""
def decorator_callFromPath(callFunc):
def wrapper_callFromPath(*args, **kwargs):
......@@ -47,3 +118,245 @@ def callFromPath(path):
return result
return wrapper_callFromPath
return decorator_callFromPath
def userQuery(query, choices=None):
"""query something from the user"""
choicesString = ', '.join(str(c) for c in choices) if choices else ''
querySuffix = f" (choices: {choicesString})\n" if choices else ' '
while True:
styledBotPrint(f"{query.strip()}{querySuffix}", style="highlight")
inp = input()
if choices and inp not in choices:
f"Invalid answer: '{inp}'. Choose from {choicesString}.",
return inp
def queryYesNo(question, default="yes"):
"""query a yes/no answer from the user"""
affirmative = ["yes", "y", "ye"]
negative = ["no", "n"]
def getChoices():
return ", ".join(c for c in affirmative + negative)
def isAffirmative(choice): return choice in affirmative
def isNegative(choice): return choice in negative
def isValid(choice): return isAffirmative(choice) or isNegative(choice)
if default is not None and not isValid(default):
raise ValueError("\nInvalid default answer: '{}', choices: '{}'\n"
.format(default, getChoices()))
if default is None:
prompt = " [y/n] "
prompt = " [Y/n] " if isAffirmative(default) else " [y/N] "
while True:
styledBotPrint(f"{question.strip()}{prompt}", style="highlight", end="")
choice = input().lower()
if default is not None and choice == "":
return True if isAffirmative(default) else False
if not isValid(choice):
f"Invalid answer: '{choice}'. Choose from '{getChoices()}'",
return True if isAffirmative(choice) else False
def cppHeaderFilter():
return lambda fileName: fileName == 'config.h'
def includedCppProjectHeaders(file,
"""get all project headers included by a cpp file"""
filePath = os.path.join(projectBase, file)
if not os.path.exists(filePath):
raise IOError(f'Cpp file {filePath} does not exist')
with open(filePath, 'r') as f:
content =
headerInBracket = re.findall(r'#include\s+<(.+?)>', content)
headerInQuotation = re.findall(r'#include\s+"(.+?)"', content)
def process(pathInProject):
headerPath = os.path.join(projectBase, pathInProject)
if os.path.exists(headerPath):
if not headerFilter(pathInProject):
if headerPath not in headers:
headerPath, projectBase,
headers, headerFilter
for header in headerInBracket:
for header in headerInQuotation:
absHeaderPath = os.path.join(os.path.dirname(file), header)
projectPath = os.path.relpath(absHeaderPath, projectBase)
return headers
def findMatchingFiles(path, pattern):
"""find all files below the given folder that match the given pattern"""
result = []
for root, dirs, files in os.walk(path):
relativeRootPath = os.path.relpath(root, path)
for file in files:
if fnmatch.fnmatch(file, pattern):
result.append(os.path.join(relativeRootPath, file))
return result
def isGitRepository(pathToRepo='.'):
run = callFromPath(pathToRepo)(runCommand)
run('git status')
return True
except Exception:
return False
def getRemote(pathToRepo='.'):
run = callFromPath(pathToRepo)(runCommand)
return run('git ls-remote --get-url').strip('\n')
def fetchRepo(remote, pathToRepo='.'):
run = callFromPath(pathToRepo)(runCommand)
run('git fetch {}'.format(remote))
def hasUntrackedFiles(pathToRepo='.'):
run = callFromPath(pathToRepo)(runCommand)
return run('git ls-files --others --exclude-standard') != ''
def isPersistentBranch(branchName):
if branchName == 'origin/master':
return True
if branchName.startswith('origin/releases/'):
return True
return False
# get the most recent commit that also exists on remote master/release branch
# may be used to find a commit we can use as basis for a pub module
def mostRecentCommonCommitWithRemote(modFolderPath,
run = callFromPath(modFolderPath)(runCommand)
def findBranches(sha):
candidates = run('git branch -r --contains {}'.format(sha)).split('\n')
candidates = [branch.strip().split(' ->')[0] for branch in candidates]
return list(filter(branchFilter, candidates))
revList = run('git rev-list HEAD').split('\n')
for rev in revList:
branches = findBranches(rev)
if branches:
return branches[0], rev
raise RuntimeError('Could not find suitable ancestor commit'
' on a branch that matches the given filter')
# function to extract persistent, remotely available git versions for all
def getPersistentVersions(modFolderPaths, ignoreUntracked=False):
result = {}
for modFolderPath in modFolderPaths:
if not isGitRepository(modFolderPath):
raise Exception('Folder is not a git repository')
if hasUntrackedFiles(modFolderPath) and not ignoreUntracked:
raise Exception(
"Found untracked files in '{}'. "
"Please commit, stash, or remove them. Alternatively, if you "
"are sure they are not needed set ignoreUntracked=True"
result[modFolderPath] = {}
result[modFolderPath]['remote'] = getRemote(modFolderPath)
# update remote to make sure we find all upstream commits
fetchRepo(result[modFolderPath]['remote'], modFolderPath)
branch, rev = mostRecentCommonCommitWithRemote(modFolderPath)
run = callFromPath(modFolderPath)(runCommand)
result[modFolderPath]['revision'] = rev
result[modFolderPath]['date'] = run(
'git log -n 1 --format=%ai {}'.format(rev)
result[modFolderPath]['author'] = run(
'git log -n 1 --format=%an {}'.format(rev)
# this may return HEAD if we are on some detached HEAD tree
result[modFolderPath]['branch'] = branch
return result
def getPatches(persistentVersions):
result = {}
for path, gitInfo in persistentVersions.items():
run = callFromPath(path)(runCommand)
uncommittedPatch = run('git diff')
unpublishedPatch = run(
'git format-patch --stdout {}'.format(gitInfo['revision'])
untrackedPatch = ''
untrackedFiles = run('git ls-files --others --exclude-standard')
binaryExtension = (
'.png', '.gif', '.jpg', '.tiff', '.bmp', '.DS_Store', '.eot', '.otf', '.ttf', '.woff', '.rgb', '.pdf',
if untrackedFiles:
for file in untrackedFiles.splitlines():
if not str(file).endswith(binaryExtension):
untrackedPatch += run('git --no-pager diff /dev/null {}'.format(file), check=False)
result[path] = {}
result[path]['untracked'] = untrackedPatch if untrackedPatch else None
result[path]['unpublished'] = unpublishedPatch if unpublishedPatch else None
result[path]['uncommitted'] = uncommittedPatch if uncommittedPatch else None
return result
def versionTable(versions,
'name': 'module name',
'branch': 'branch name',
'revision': 'commit sha',
'date': 'commit date'
return makeTable(versions, config)
def printVersionTable(versions):
#!/usr/bin/env python3
import os
import argparse
from common import runCommand
from common import callFromPath
# print warning message for scanned folders that are not git repositories
def printNoGitRepoWarning(folderPath):
print("Folder " + folderPath + " does not seem to be the top level" \
"of a git repository and will be skipped. Make sure not to call " \
"this script from a sub-directory of a git repository.")
# raise error due to untracked files present in the given module folder
def raiseUntrackedFilesError(folderPath):
raise RuntimeError('Found untracked files in module folder: "' + folderPath + '". ' \
'Please commit, stash, or remove them.')
# returns true if the given folder is a git repository
def isGitRepository(modFolderPath):
return os.path.exists(os.path.join(modFolderPath, '.git'))
# returns true if a module contains untracked files
def hasUntrackedFiles(modFolderPath):
run = callFromPath(modFolderPath)(runCommand)
return run('git ls-files --others --exclude-standard') != ''
# function to extract git version information for modules
# returns a dictionary containing module information for each given module folder
def getUsedVersions(modFolderPaths, ignoreUntracked=False):
result = {}
for modFolderPath in modFolderPaths:
# make sure this is the top level of a git repository
if not isGitRepository(modFolderPath):
if not ignoreUntracked and hasUntrackedFiles(modFolderPath):
run = callFromPath(modFolderPath)(runCommand)
result[modFolderPath] = {}
result[modFolderPath]['remote'] = run('git ls-remote --get-url').strip('\n')
result[modFolderPath]['revision'] = run('git log -n 1 --format=%H @{upstream}').strip('\n')
result[modFolderPath]['date'] = run('git log -n 1 --format=%ai @{upstream}').strip('\n')
result[modFolderPath]['author'] = run('git log -n 1 --format=%an @{upstream}').strip('\n')
result[modFolderPath]['branch'] = run('git rev-parse --abbrev-ref HEAD').strip('\n')
return result
# create patches for unpublished commits and uncommitted changes in modules
def getPatches(modFolderPaths, ignoreUntracked=False):
result = {}
for modFolderPath in modFolderPaths:
# make sure this is the top level of a git repository
if not isGitRepository(modFolderPath):
if not ignoreUntracked and hasUntrackedFiles(modFolderPath):
run = callFromPath(modFolderPath)(runCommand)
unpubPatch = run('git format-patch --stdout @{upstream}')
unCommPatch = run('git diff')
if unpubPatch != '' or unCommPatch != '': result[modFolderPath] = {}
if unpubPatch != '': result[modFolderPath]['unpublished'] = unpubPatch
if unCommPatch != '': result[modFolderPath]['uncommitted'] = unCommPatch
return result
# prints the detected versions as table
def printVersionTable(versions):
print("\t| {:^50} | {:^50} | {:^50} | {:^30} |".format('module folder', 'branch', 'commit hash', 'commit date'))
print("\t" + 193*'-')
for folder, versionInfo in versions.items():
print("\t| {:^50} | {:^50} | {:^50} | {:^30} |".format(folder, versionInfo['branch'], versionInfo['revision'], versionInfo['date']))
# For standalone execution
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='This script extracts the used dune/dumux versions.')
parser.add_argument('-p', '--path', required=False, help='the path to the top folder containing your dune/dumux modules')
parser.add_argument('-i', '--ignoreuntracked', required=False, action='store_true', help='use this flag to ignore untracked files present in the modules')
parser.add_argument('-s', '--skipfolders', required=False, nargs='*', help='a list of module folders to be skipped')
cmdArgs = vars(parser.parse_args())
modulesPath = os.getcwd() if not cmdArgs['path'] else os.path.join(os.getcwd(), cmdArgs['path'])
print('\nDetermining the versions of all dune modules in the folder: ' + modulesPath)
def getPath(modFolder):
return os.path.join(modulesPath, modFolder)
modFolderPaths = [getPath(dir) for dir in os.listdir(modulesPath) if os.path.isdir(getPath(dir))]
if cmdArgs['skipfolders']:
cmdArgs['skipfolders'] = [f.strip('/') for f in cmdArgs['skipfolders']]
modFolderPaths = [d for d in modFolderPaths if os.path.basename(d.strip('/')) not in cmdArgs['skipfolders']]
versions = getUsedVersions(modFolderPaths, True)
print("\nDetected the following versions:")
# maybe check untracked files
if not cmdArgs['ignoreuntracked']:
modsWithUntracked = [f for f in versions if hasUntrackedFiles(f)]
if modsWithUntracked:
print('WARNING: Found untracked files in the following modules:\n\n')
print('\nPlease make sure that these are not required for your purposes.')
print('If not, you can run this script with the option -i/--ignoreuntracked to suppress this warning.')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment