Universal Experiment Tracker

7 minute read

Published:

Robust experiment tracker for algorithms running in a loop (e.g. grid search). Prevents data loss and allows for simple recovery on error. Iterations of the search loop can be skipped by simply calling the check(…) function with the set of parameters.

Supports time tracking by retrieving the last logged value. Provides easy to read console output.

User Functions: ExperimentTracker(path: Union[str, Path], tracker_name: str, clean_start: bool = False, raise_error_on_rewrite: bool = False, **kwargs), log(results: dict, **kwargs), check(**kwargs) -> bool, get_next_log() -> dict, get_last_log() -> dict, report()

NOTICE: **kwargs always refers to the independent variables (e.g. X and Y in a grid search). They must be passed with the same names as when provided to the constructor.

"""Universal Experiment Tracker
        
    AUTHOR: Martin Garaj (2023) garaj.martin@gmail.com
"""

import pandas as pd
from pathlib import Path
from typing import Union
from itertools import product
import numpy as np
import datetime
import warnings

class ExperimentTracker():
    """Handle logging of measured data for experiments run in a (grid) search.

    Every combination of the independent variables (the **kwargs given to the
    constructor) becomes one row of an internal pandas DataFrame.  The frame
    is persisted to CSV after every log() call, which prevents data loss and
    lets an interrupted search restart from the last checkpoint: already
    processed combinations can be skipped via check(...).

    Bookkeeping columns:
        '__processed__' : True once results were logged for the row
        '__datetime__'  : timestamp of the moment the row was logged
    """
    def __init__(self, 
            path:Union[str, Path], 
            tracker_name:str, 
            clean_start:bool=False,
            raise_error_on_rewrite:bool=False, 
            **kwargs):
        """Initialize the tracker, reloading a previous file unless clean_start.

        :param path: Existing directory where the tracker file is stored.
        :type path: Union[str, Path]
        :param tracker_name: File name of the tracker (stored as CSV).
        :type tracker_name: str
        :param clean_start: If True, ignore any existing tracker file and
            start from scratch, defaults to False.
        :type clean_start: bool
        :param raise_error_on_rewrite: If True, log() raises RuntimeError
            instead of warning when rewriting a processed entry.
        :type raise_error_on_rewrite: bool
        :param kwargs: Independent variables with their values,
            e.g. var_A=[1, 2, 3].
        """
        # Path(Path(...)) is a no-op, so no isinstance check is needed.
        self._path = Path(path)
        self._name = tracker_name
        self._fullpath = self._path / self._name
        self._vars = kwargs
        self._df = None
        self._raise_error = raise_error_on_rewrite

        if not clean_start and self._check_tracker_exists():
            self._load_tracker()
        else:
            self._create_new_tracker()


################################################################################
##                               USER FUNCTIONS                               ##
################################################################################ 
    def log(self, results:dict, **kwargs):
        """Log new results for the row selected by the independent variables.

        A new row is appended when the combination is not present in the
        tracker.  The tracker file is saved after every call (checkpointing).

        :param results: Dictionary of new results (single value per key)
        :type results: dict
        :raises RuntimeError: If no independent variables are provided, or if
            an already processed entry would be rewritten while
            raise_error_on_rewrite is set.
        :raises ValueError: If 'kwargs' point to more than 1 entry
        """
        if not kwargs:
            # consistent with check(): an empty selection cannot address a row
            # (previously this crashed inside DataFrame.query with an
            # uninformative "expr cannot be empty" error)
            raise RuntimeError(f"log(...) requires a list of parameters={', '.join(self._get_vars_columns())}")

        index = self._find_row_index(**kwargs)

        if index is None:
            self._append_to_df(results=results, **kwargs)
        elif isinstance(index, list):
            raise ValueError(f"Provided columns match {len(index)} rows.")
        else:
            if True == self._df.loc[index, '__processed__']:
                if self._raise_error:
                    raise RuntimeError(f"Raising error to prevent rewriting experiment results for ({str(self._get_row(index))})")
                else:
                    warnings.warn(f"Warning: rewriting existing entry at index={index} ({str(self._get_row(index))})")
            self._write_to_df(index, results)

        self._save_tracker()


    def check(self, **kwargs) -> bool:
        """Returns True if experiment has been logged/processed, else False

        :raises RuntimeError: Provided **kwargs do not sufficiently identify
            the independent variables
        :raises ValueError: Internal error when locating 'index'
        :return: True->experiment was run, False->experiment has not run
        :rtype: bool
        """
        if not kwargs:
            raise RuntimeError(f"check(...) requires a list of parameters={', '.join(self._get_vars_columns())}")

        index = self._find_row_index(**kwargs)

        if index is None:
            return False
        elif isinstance(index, list):
            raise ValueError(f"Provided columns match {len(index)} rows.")
        elif isinstance(index, (int, np.integer)):
            # np.integer accepted for robustness against pandas index types
            _row = self._get_row(index)
            return _row['__processed__']
        else:
            raise ValueError(f"index has unexpected type={type(index)}.")


    def get_next_log(self) -> dict:
        """Return the first log row, where '__processed__' == False

        :return: If there is next row, returns dictionary, otherwise empty dict
        :rtype: dict
        """
        try:
            if self._df['__processed__'].all():
                # everything processed (also covers the empty frame)
                return dict()
            else:
                # idxmax returns the index of the first True, i.e. the first
                # row that has not been processed yet
                index = (self._df['__processed__'] == False).idxmax()
                return self._get_row(index)
        except ValueError:
            return dict()


    def get_last_log(self) -> dict:
        """Returns the latest log row (without the bookkeeping columns)

        :return: If there is latest entry, returns dictionary, 
            otherwise empty dict
        :rtype: dict
        """
        try:
            index = pd.to_datetime(self._df['__datetime__']).idxmax()
            _row = self._get_row(index)
            # hide the internal bookkeeping columns from the user
            del _row['__datetime__']
            del _row['__processed__']
            return _row
        except ValueError:
            # raised by idxmax when no timestamp exists yet
            return dict()


    def report(self):
        """Print a human-readable summary of the tracker state."""
        print(f"========== EXPERIMENT TRACKER REPORT ==========")
        print(f">      name : {self._name}")
        print(f">      path : {self._fullpath.resolve()}")
        print(f">   records : {sum(self._df['__processed__'])}/{len(self._df.index)}")
        print(f"> variables : {', '.join(self._get_vars_columns())}")
        print(f">   results : {', '.join(self._get_result_columns())}")
        _datetime = self._get_last_log().get('__datetime__', None)
        # BUGFIX: the original computed `datetime.now() - _datetime`
        # unconditionally, which raised TypeError both when nothing was
        # logged yet (_datetime == 'N/A') and after reloading from CSV
        # (where '__datetime__' is a plain string).
        if _datetime is None or pd.isna(_datetime):
            print(f">  last log : N/A")
        else:
            # pd.to_datetime accepts both datetime objects and CSV strings
            _datetime = pd.to_datetime(_datetime).to_pydatetime()
            time_delta = datetime.datetime.now() - _datetime
            days = time_delta.days
            hours, remainder = divmod(time_delta.seconds, 3600)
            minutes, seconds = divmod(remainder, 60)
            formatted_time_delta = f"{days}d {hours:02}h {minutes:02}m {seconds:02}s"
            print(f">  last log : {_datetime} ({formatted_time_delta} ago)")
        _last_vars = [f"'{key}':{val}" for key, val in self.get_last_log().items()]
        print(f">            {' , '.join(_last_vars)}")



################################################################################
##                             INTERNAL FUNCTIONS                             ##
################################################################################ 

    def _check_tracker_exists(self) -> bool:
        """Check if previous tracker file exists.

        :raises ValueError: Path doesn't exist
        :return: True if tracker file exists, otherwise False
        :rtype: bool
        """
        if not self._path.exists():
            raise ValueError(f"path={str(self._path)} doesn't exist")
        return self._fullpath.exists()


    def _create_new_tracker(self):
        """Create new tracker DataFrame (in memory; saved on first log()).

        The dataframe includes all '**kwargs' passed to __init__ as columns,
        extended by internal columns '__processed__' and '__datetime__'.
        One row is created per element of the Cartesian product of the
        non-empty variable value lists.
        """
        # filter out empty lists but remember all keys for columns
        non_empty_args = {k: v for k, v in self._vars.items() if v}
        all_keys = ['__processed__', '__datetime__'] + list(self._vars.keys())

        # Cartesian product of the non-empty variables; note that
        # product() of zero iterables yields a single empty tuple
        cartesian_product = list(product(*non_empty_args.values()))

        if not cartesian_product or len(cartesian_product[0]) == 0:
            # no variables provided -> empty frame with the right columns
            df = pd.DataFrame(columns=all_keys)
        else:
            df = pd.DataFrame(cartesian_product, columns=list(non_empty_args.keys()))
            # add the bookkeeping columns and any all-empty variables
            for key in all_keys:
                if key not in df.columns:
                    df[key] = np.nan
            df['__processed__'] = False

        # reorder the columns to the canonical order
        self._df = df[all_keys]


    def _save_tracker(self):
        """Save the tracker DataFrame to the CSV file."""
        self._df.to_csv(path_or_buf=self._fullpath, index=False)


    def _load_tracker(self):
        """Load the tracker DataFrame from the CSV file."""
        self._df = pd.read_csv(filepath_or_buffer=self._fullpath)


    def _find_row_index(self, **kwargs) -> Union[int, list, None]:
        """Find the matching row index for the given independent variables.

        :return: index value (int if unique, list if multiple matches,
            None if no match)
        :rtype: Union[int, list, None]
        """
        # build a query string; string values must be quoted for df.query
        query_str = ' and '.join([f"{k} == '{v}'" if isinstance(v, str) else f"{k} == {v}" for k, v in kwargs.items()])

        matching_rows = self._df.query(query_str)

        # .tolist() converts numpy integers to plain Python ints
        indices = matching_rows.index.tolist()
        if not indices:
            return None
        if len(indices) == 1:
            return indices[0]
        return indices


    def _append_to_df(self, results:dict, **kwargs):
        """Append a new, already-processed row.

        :param results: Dictionary of results (single value per key)
        :type results: dict
        """
        new_data = {**kwargs, **results,
                    '__datetime__': self._get_datetime(),
                    '__processed__': True}
        df_new = pd.DataFrame([new_data])
        self._df = pd.concat([self._df, df_new], ignore_index=True)


    def _write_to_df(self, index:int, results:dict):
        """Write results into an existing row and mark it processed.

        :param index: row index where the new data are written
        :type index: int
        :param results: Dictionary of results to be stored at the row.
        :type results: dict
        """
        for column, value in results.items():
            if column not in self._df.columns:
                # first time this result key is seen -> create the column
                self._df[column] = None
            self._df.loc[index, column] = value
        self._df.loc[index, '__datetime__'] = self._get_datetime()
        self._df.loc[index, '__processed__'] = True


    def _get_datetime(self) -> datetime.datetime:
        """Get current date-time

        :return: datetime
        :rtype: datetime.datetime
        """
        return datetime.datetime.now()


    def _get_row(self, index) -> dict:
        """Get row with all columns

        :param index: Index of a row
        :type index: int
        :return: Dictionary of a row at index
        :rtype: dict
        """
        columns = list(self._df.columns)
        return self._df.loc[index, columns].to_dict()


    def _get_vars_columns(self) -> list:
        """Get all columns associated with independent variables

        :return: List of columns associated with variables
        :rtype: list
        """
        return list(self._vars.keys())


    def _get_result_columns(self) -> list:
        """Get all columns associated with results

        :return: List of columns associated with results
        :rtype: list
        """
        internal = ['__processed__', '__datetime__']
        vars_columns = self._get_vars_columns()
        return [column for column in self._df.columns
                if column not in vars_columns and column not in internal]


    def _get_last_log(self) -> dict:
        """Returns the latest log row including the bookkeeping columns.

        :return: If there is latest entry, returns dictionary, 
            otherwise empty dict
        :rtype: dict
        """
        try:
            index = pd.to_datetime(self._df['__datetime__']).idxmax()
            return self._get_row(index)
        except ValueError:
            # raised by idxmax when no timestamp exists yet
            return dict()

################################################################################
##                                     END                                    ##
################################################################################ 


if '__main__' == __name__:
    import time

    print('================================= RUNNING TEST =================================')
    # number of experiments executed in LOOP 1
    counter = 0
    # independent variables of the demo grid search
    var_A = [1,2,3]
    var_B = ['short', 'long', 'brief']
    # fresh tracker: clean_start wipes any previous file, and rewriting an
    # already logged combination raises RuntimeError
    tracker = ExperimentTracker(path='./', tracker_name='experimnet_1', 
                                clean_start=True, 
                                raise_error_on_rewrite=True,
                                var_A=var_A, 
                                var_B=var_B,
                                )

    print('================ INTERNAL DF ================')
    print(tracker._df)
    print('---------------------------------------------')

    print('=================== LOOP 1 ==================')
    # log only the first 6 combinations, then stop (simulates interruption)
    for a, b in product(var_A, var_B):
        counter += 1
        time.sleep(0.4)
        tracker.log(results={'r_1':a*2+0.5, 'r_2':f"{a}{b}"}, var_A=a, var_B=b)

        if counter > 5:
            break
    print('---------------------------------------------')
    tracker.report()

    print('================ INTERNAL DF ================')
    print(tracker._df)
    print('---------------------------------------------')

    print('=================== LOOP 2 ==================')
    # resume the search: skip combinations that were already logged
    for a, b in product(var_A, var_B):
        already_logged = tracker.check(var_A=a, var_B=b)
        print(f"{a}-{str(b).ljust(8)} = {already_logged}")

        if already_logged:
            print(f"     ALREADY LOGGED (SKIPPING)     {a}-{str(b).ljust(8)}")
        else:
            tracker.log(results={'r_1':a*2+0.5, 'r_2':f"{a}{b}"}, var_A=a, var_B=b)
            print(f"     LOGGING NEW EXPERIMENT        {a}-{str(b).ljust(8)}")
    print('---------------------------------------------')


    print('================ INTERNAL DF ================')
    print(tracker._df)
    print('---------------------------------------------')

    print('================ RAISE ERROR ================')
    # re-logging the last combination must raise (raise_error_on_rewrite=True)
    tracker.log(results={'r_1':a*2+0.5, 'r_2':f"{a}{b}"}, var_A=a, var_B=b)