diff --git a/pymacro/pymacro b/pymacro/pymacro index e0f61c3..3820273 100755 --- a/pymacro/pymacro +++ b/pymacro/pymacro @@ -18,11 +18,11 @@ def get_args(): """ Get command line arguments """ import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-m", "--macros", default=["macros"], action="append", - help="Extra files where macros are stored") - parser.add_argument("-i", "--input", help="The file to be processed", default="-") - parser.add_argument("-o", "--output", help="The location of the output", default="-") + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-m", "--macros-file", default="macros", + help="File where macros are stored") + parser.add_argument("-i", "--input", help="File to be processed.", default="-") + parser.add_argument("-o", "--output", help="Path of output", default="-") return parser.parse_args() @@ -86,7 +86,10 @@ def upper_check(token, word): def process(input, macros): tokens, otf_macros = tokenize(input) output = tokens - macros = otf_macros + macros + macros = macros + + for key in otf_macros.keys(): + macros[key] = otf_macros[key] for line_number, line in enumerate(tokens): for token_number, token in enumerate(line): @@ -110,13 +113,10 @@ def process(input, macros): # will not be changed value = token - for macro in macros: - if macro[0].lower() == token.lower(): - value = macro[1] - break - elif macro[0].lower() + 's' == token.lower(): - value = pluralize(macro[1], macro=macro) - break + if token.lower() in macros.keys(): + value = macros[token] + elif f'{token.lower()}s' in macros.keys(): + value = pluralize(macro[1], macro=macro) output[line_number][token_number] = upper_check(token, value) output[line_number][token_number] += end @@ -131,7 +131,7 @@ def tokenize(input): """ tokens = [x.split(' ') for x in input.split('\n')] - otf_macros = [] + otf_macros = {} in_otf_macro = False tmp_macro_keyword = None tmp_macro_definition = [] @@ -151,7 +151,7 @@ def tokenize(input): split_token = re.split(r',.|.,', token) tmp_macro_definition.append(split_token[0]) tokens[line_index][token_index] = tmp_macro_keyword + split_token[1] - otf_macros.append((tmp_macro_keyword, ' '.join(tmp_macro_definition))) + otf_macros[tmp_macro_keyword] = ' '.join(tmp_macro_definition) in_otf_macro = False continue elif in_otf_macro: @@ -193,18 +193,22 @@ def get_macros(input): """ Turn a macros string into a list of tuples of macros """ - response = [] + response = {} - # turn input into unvalidated list of macros + # turn input into list of tuples macros = [re.split('[\t]', x) for x in input.split('\n')] - # validate macros + # check if keyword is `source`, get macros from sourced file if it is for index, macro in enumerate(macros): if macro[0] == "source": with open(macro[1]) as file: - response += get_macros(file.read()) + macros += get_macros(file.read()) + macros[index] = () + + # store macros as dict and return + for index, macro in enumerate(macros): if len(macro) >= 2: - response.append(tuple(macros[index])) + response[macro[0].lower()] = macro[1] return response @@ -237,10 +241,8 @@ def main(args): # get macros - macros = [] - for macro_file in args.macros: - with open(macro_file) as file: - macros += get_macros(file.read()) + with open(args.macros_file) as file: + macros = get_macros(file.read()) # get tokens (file contents) if args.input == "-":