Source code for py_entitymatching.matcherselector.mlmatcherselection

"""
This module contains functions for selecting a ML-matcher.
"""
import logging
from collections import OrderedDict

import pandas as pd
import numpy as np
from sklearn.model_selection import KFold, cross_val_score

from py_entitymatching.utils.catalog_helper import check_attrs_present
from py_entitymatching.utils.generic_helper import list_diff, list_drop_duplicates
from py_entitymatching.utils.validation_helper import validate_object_type

logger = logging.getLogger(__name__)


[docs]def select_matcher(matchers, x=None, y=None, table=None, exclude_attrs=None, target_attr=None, metric_to_select_matcher='precision', metrics_to_display=['precision', 'recall', 'f1'], k=5, n_jobs=-1, random_state=None): """ This function selects a matcher from a given list of matchers based on a given metric. Specifically, this function internally uses scikit-learn's cross validation function to select a matcher. There are two ways the user can call the fit method. First, interface similar to scikit-learn where the feature vectors and target attribute given as projected DataFrame. Second, give the DataFrame and explicitly specify the feature vectors (by specifying the attributes to be excluded) and the target attribute A point to note is all the input parameters have a default value of None. This is done to support both the interfaces in a single function. Args: matchers (MLMatcher): List of ML matchers to be selected from. x (DataFrame): Input feature vectors given as pandas DataFrame ( defaults to None). y (DatFrame): Input target attribute given as pandas DataFrame with a single column (defaults to None). table (DataFrame): Input pandas DataFrame containing feature vectors and target attribute (defaults to None). exclude_attrs (list): The list of attributes that should be excluded from the input table to get the feature vectors. target_attr (string): The target attribute in the input table (defaults to None). metric_to_select_matcher (string): The metric based on which the matchers must be selected. The string can be one of 'precision', 'recall', 'f1' (defaults to 'precision'). metrics_to_display (list): The metrics that will be displayed to the user. It should be a list of any of the strings 'precision', 'recall', or 'f1' (defaults to ['precision', 'recall', 'f1']). k (int): The k value for cross-validation (defaults to 5). n_jobs (integer): The number of CPUs to use to do the computation. -1 means 'all CPUs (defaults to -1)'. random_state (object): Pseudo random number generator that should be used for splitting the data into folds (defaults to None). Returns: A dictionary containing three keys - selected matcher, cv_stats, and drill_down_cv_stats. The selected matcher has a value that is a matcher (MLMatcher) object, cv_stats is a Dataframe containing average metrics for each matcher, and drill_down_cv_stats is a dictionary containing a table for each metric the user wants to display containing the score of the matchers for each fold. Raises: AssertionError: If `metric_to_select_matcher` is not one of 'precision', 'recall', or 'f1'. AssertionError: If each item in the list `metrics_to_display` is not one of 'precision', 'recall', or 'f1'. Examples: >>> dt = em.DTMatcher() >>> rf = em.RFMatcher() # train is the feature vector containing user labels >>> result = em.select_matcher(matchers=[dt, rf], table=train, exclude_attrs=['_id', 'ltable_id', 'rtable_id'], target_attr='gold_labels', k=5) """ # Check that metrics_to_display is valid if not isinstance(metrics_to_display, list): metrics_to_display = [metrics_to_display] for met in metrics_to_display: if met not in ['precision', 'recall', 'f1']: logger.error('Metrics must be either "precision", "recall", or "f1".') raise AssertionError('Metrics must be either "precision", "recall", or "f1".') # Check that metric_to_select_matcher is valid if metric_to_select_matcher not in ['precision', 'recall', 'f1']: logger.error('Metric must be either "precision", "recall", or "f1".') raise AssertionError('Metric must be either "precision", "recall", or "f1".') # Based on the input, get the x, y data that can be used to call the # scikit-learn's cross validation method x, y = _get_xy_data(x, y, table, exclude_attrs, target_attr) max_score = 0 # Initialize the best matcher. As of now set it to be the first matcher. sel_matcher = matchers[0] # Fix the header header = ['Name', 'Matcher', 'Num folds'] # # Append the folds fold_header = ['Fold ' + str(i + 1) for i in range(k)] header.extend(fold_header) # Finally, append the score. header.append('Mean score') # Initialize the drill_down_cv_stats dictionary drill_down_cv_stats = OrderedDict() # Initialize the cv_stats_dictionary and set the first key to contain the matcher names cv_stats_dict = OrderedDict() matcher_names = [] for m in matchers: matcher_names.append(m.get_name()) cv_stats_dict['Matcher'] = matcher_names # Run the cross_validation for each metric for met in metrics_to_display: dict_list = [] mean_score_list = [] # Run the cross validation for each matcher for m in matchers: # Use scikit learn's cross validation to get the matcher and the list # of scores (one for each fold). matcher, scores = cross_validation(m, x, y, met, k, random_state, n_jobs) # Fill a dictionary based on the matcher and the scores. val_list = [matcher.get_name(), matcher, k] val_list.extend(scores) val_list.append(np.mean(scores)) d = OrderedDict(zip(header, val_list)) dict_list.append(d) # Create a list of the mean scores for each matcher mean_score_list.append(np.mean(scores)) # Select the matcher based on the mean score, but only for metric_to_select_matcher if met == metric_to_select_matcher and np.mean(scores) > max_score: sel_matcher = m max_score = np.mean(scores) # Create a DataFrame based on the list of dictionaries created stats = pd.DataFrame(dict_list) stats = stats[header] drill_down_cv_stats[met] = stats # Add the mean scores for this metric to the cv_stats_dictionary cv_stats_dict['Average ' + met] = mean_score_list res = OrderedDict() # Add selected matcher and the stats to a dictionary. res['selected_matcher'] = sel_matcher res['cv_stats'] = pd.DataFrame(cv_stats_dict) res['drill_down_cv_stats'] = drill_down_cv_stats # Result the final dictionary containing selected matcher and the CV # statistics. return res
def cross_validation(matcher, x, y, metric, k, random_state, n_jobs): """ The function does cross validation for a single matcher """ # Use KFold function from scikit learn to create a ms object that can be # used for cross_val_score function. cv = KFold(k, shuffle=True, random_state=random_state) # Call the scikit-learn's cross_val_score function scores = cross_val_score(matcher.clf, x, y, scoring=metric, cv=cv, n_jobs=n_jobs) # Finally, return the matcher along with the scores. return matcher, scores def _get_xy_data(x, y, table, exclude_attrs, target_attr): """ Gets the X, Y data from the input based on the given table, the exclude attributes and target attribute provided. """ # If x and y is given, call appropriate function if x is not None and y is not None: return _get_xy_data_prj(x, y) # If table, exclude attributes and target attribute are given, # call appropriate function elif table is not None and exclude_attrs is not None \ and target_attr is not None: return _get_xy_data_ex(table, exclude_attrs, target_attr) else: # Else, raise a syntax error. raise SyntaxError('The arguments supplied does not match ' 'the signatures supported !!!') def _get_xy_data_prj(x, y): """ Gets X, Y that can be used for scikit-learn function based on given projected tables. """ # If the first column is '_id' the remove it if x.columns[0] == '_id': logger.warning( 'Input table contains "_id". Removing this column for processing') # Get the values from the DataFrame x = x.values x = np.delete(x, 0, 1) else: x = x.values if not isinstance(y, pd.Series): logger.error('Target attr is expected to be a pandas series') raise AssertionError('Target attr is expected to be a pandas series') else: # Get the values from the DataFrame y = y.values # Finally, return x an y return x, y def _get_xy_data_ex(table, exclude_attrs, target_attr): # Validate the input parameters # # We expect the input table to be of type pandas DataFrame validate_object_type(table, pd.DataFrame) # We expect exclude attributes to be of type list. If not convert it into # a list. if not isinstance(exclude_attrs, list): exclude_attrs = [exclude_attrs] # Check if the exclude attributes are present in the input table if not check_attrs_present(table, exclude_attrs): logger.error('The attributes mentioned in exclude_attrs ' 'is not present ' 'in the input table') raise AssertionError( 'The attributes mentioned in exclude_attrs ' 'is not present ' 'in the input table') # Check if the target attribute is present in the input table if not check_attrs_present(table, target_attr): logger.error('The target_attr is not present in the input table') raise AssertionError( 'The target_attr is not present in the input table') # Drop the duplicates from the exclude attributes exclude_attrs = list_drop_duplicates(exclude_attrs) # Explicitly add the target attribute to exclude attribute (if it is not # already present) if target_attr not in exclude_attrs: exclude_attrs.append(target_attr) # Project the list of attributes that should be used for scikit-learn's # functions. attrs_to_project = list_diff(list(table.columns), exclude_attrs) # Get the values for x x = table[attrs_to_project].values # Get the values for x y = table[target_attr].values y = y.ravel() # to mute warnings from svm and cross validation # Return x and y return x, y