Update fastai union annotations using ast

July 08, 2022

This notebook defines and exports a lightweight command-line tool that updates union annotations in notebooks from the fastai tuple style (x:(int,str)) to the Python 3.10 union operator (x:int|str). It is built on the ast module from the standard library and developed with nbprocess.

#|default_exp __main__
#|export
import ast
import sys
from execnb.nbio import read_nb, write_nb
from fastcore.test import test_eq
#|export
def tuple2bitor(annot):
    "Convert fastai tuple style union annotation to py310 union operator"
    # `annot` is an `ast.Tuple`; its members live in `elts` (`dims` is only a deprecated alias of it).
    members = annot.elts
    # Fold left-to-right so `(a, b, c)` becomes the same tree as `a|b|c`, i.e. `(a|b)|c`.
    # A one-element tuple returns that element unchanged; an empty tuple raises `IndexError`.
    bitor = members[0]
    for right in members[1:]: bitor = ast.BinOp(left=bitor, right=right, op=ast.BitOr())
    return bitor

def tuple2bitorstr(annot):
    "Source text of the union-operator form of tuple annotation `annot`, with every space removed"
    rendered = ast.unparse(tuple2bitor(annot))
    # Splitting on the space character and re-joining drops each space and nothing else
    return ''.join(rendered.split(' '))
# Hand-built `ast.Tuple` for `(int, str, float)`; check its plain rendering and its union-operator rendering
a = ast.Tuple([ast.Name(id=o) for o in ('int','str','float')])
test_eq(ast.unparse(a),'(int, str, float)')
test_eq(tuple2bitorstr(a),'int|str|float')
#|export
def split_parts(source, node):
    "Split `source` into parts before, containing, and after `node`"
    # Raised explicitly (rather than via `assert`) so the check still runs under `python -O`
    if node.lineno != node.end_lineno: raise AssertionError('Multi-line annotations not supported')
    lines = source.split('\n')
    l = node.lineno-1
    # `col_offset`/`end_col_offset` are UTF-8 byte offsets, not character indices,
    # so slice the encoded line and decode each piece back to `str`
    line = lines[l].encode('utf-8')
    s,e = node.col_offset, node.end_col_offset
    before,mid,after = (o.decode('utf-8') for o in (line[:s], line[s:e], line[e:]))
    return '\n'.join(lines[:l]+[before]), mid, '\n'.join([after]+lines[l+1:])
# Example source: a tuple annotation sitting on its own line inside a multi-line signature
s = '''
def f(
    x: (int, str, float),
    y=5
): pass'''
n = ast.parse(s)
# Annotation node of the first argument (`x`) of the first statement in `s`
a = n.body[0].args.args[0].annotation
ps = split_parts(s, a)
# The middle part is exactly the annotation text; the three parts joined together give back `s`
test_eq(ps, ('\ndef f(\n    x: ', '(int, str, float)', ',\n    y=5\n): pass'))
#|export
def replace_node(source, node, repl):
    "Return `source` with the text occupied by `node` swapped for `repl`"
    # The middle part (the node's own text) is discarded in favour of `repl`
    before, _, after = split_parts(source, node)
    return ''.join((before, repl, after))
# Uses the `s` and `a` fixtures bound earlier: only the annotation text changes, the rest of `s` is kept verbatim
test_eq(replace_node(s, a, tuple2bitorstr(a)), '\ndef f(\n    x: int|str|float,\n    y=5\n): pass')
#|export
def fix_tuple_annots(source):
    "Convert all fastai tuple style union annotations in `source` to py310 union operator"
    while True:
        # Re-parse after every replacement: the edit shifts the offsets of everything that follows it
        tree = ast.parse(source)
        annots = (getattr(o,'annotation',None) for o in ast.walk(tree))
        # An empty tuple `()` has no members to join with `|`, so it is left untouched
        a = next((o for o in annots if isinstance(o,ast.Tuple) and o.elts), None)
        if a is None: return source
        source = replace_node(source, a, tuple2bitorstr(a))
# Longer example: a decorated function whose signature already mixes `|` unions (`x`) with one tuple-style union (`sz`)
s = '''
@patch
def crop_pad(x:TensorBBox|TensorPoint|Image.Image,
    sz:(int, tuple), # Crop/pad size of input, duplicated if one value is specified
    tl:tuple=None, # Optional top-left coordinate of the crop/pad, if `None` center crop
    orig_sz:tuple=None, # Original size of input
    pad_mode:PadMode=PadMode.Zeros, # Fastai padding mode
    resize_mode=BILINEAR, # Pillow `Image` resize mode
    resize_to:tuple=None # Optional post crop/pad resize of input
):
    if isinstance(sz,int): sz = (sz,sz)
    orig_sz = fastuple(_get_sz(x) if orig_sz is None else orig_sz)
    sz,tl = fastuple(sz),fastuple(((_get_sz(x)-sz)//2) if tl is None else tl)
    return x._do_crop_pad(sz, tl, orig_sz=orig_sz, pad_mode=pad_mode, resize_mode=resize_mode, resize_to=resize_to)
'''

# Only the `sz` annotation is expected to change; comments, tuple literals in the body and all other text stay as they were
test_eq(fix_tuple_annots(s), '''
@patch
def crop_pad(x:TensorBBox|TensorPoint|Image.Image,
    sz:int|tuple, # Crop/pad size of input, duplicated if one value is specified
    tl:tuple=None, # Optional top-left coordinate of the crop/pad, if `None` center crop
    orig_sz:tuple=None, # Original size of input
    pad_mode:PadMode=PadMode.Zeros, # Fastai padding mode
    resize_mode=BILINEAR, # Pillow `Image` resize mode
    resize_to:tuple=None # Optional post crop/pad resize of input
):
    if isinstance(sz,int): sz = (sz,sz)
    orig_sz = fastuple(_get_sz(x) if orig_sz is None else orig_sz)
    sz,tl = fastuple(sz),fastuple(((_get_sz(x)-sz)//2) if tl is None else tl)
    return x._do_crop_pad(sz, tl, orig_sz=orig_sz, pad_mode=pad_mode, resize_mode=resize_mode, resize_to=resize_to)
''')
#|export
def fix_nb_tuple_annots(nb):
    "Convert all fastai tuple style union annotations in the code cells of `nb` to py310 union operator"
    for cell in nb.cells:
        # Markdown/raw cells hold prose, yet text such as `Shape: (batch, channels)` parses as an
        # annotated name and would be rewritten. Cells without a `cell_type` are treated as code.
        if getattr(cell, 'cell_type', 'code') != 'code': continue
        # Best effort: a cell that does not parse as Python is left unchanged
        try: cell.source = fix_tuple_annots(cell.source)
        except SyntaxError: pass
#|export
from fastcore.script import *
from fastcore.utils import *

@call_parse
def main(fname:str): # A notebook name or glob to convert
    "Convert all fastai tuple style union annotations in `fname` to py310 union operators"
    # `globtastic` is asked for `*.ipynb` files, with folders matching `^[_.]` skipped
    for f in globtastic(fname, file_glob='*.ipynb', skip_folder_re='^[_.]'):
        nb = read_nb(f)
        # Updates the cells of `nb` in place
        fix_nb_tuple_annots(nb)
        # Written back to the same path it was read from
        write_nb(nb, f)