Source code for robustx.generators.CEGenerator

from abc import ABC, abstractmethod
import pandas as pd
from robustx.lib.tasks.Task import Task


class CEGenerator(ABC):
    """
    Abstract base class for counterfactual explanation (CE) generators.

    The base class stores the task to solve and an optional custom distance
    function, and provides the batch entry points (``generate``,
    ``generate_for_instance``, ``generate_for_all``). The actual search for a
    counterfactual is delegated to ``_generation_method``, which every
    concrete subclass must implement.

    Attributes:
        _task (Task): The task to solve.
        __customFunc (callable, optional): A custom distance function.
    """

    # The annotation is quoted (forward reference) so that defining the class
    # does not require ``Task`` to be resolved at definition time.
    def __init__(self, ct: "Task", custom_distance_func=None):
        """
        Initializes the CEGenerator with a task and an optional custom distance function.

        @param ct: The Task instance to solve.
        @param custom_distance_func: An optional custom distance function.
        """
        self._task = ct
        self.__customFunc = custom_distance_func

    @property
    def task(self):
        """
        Returns the task passed at instantiation.

        @return: The Task instance this generator works on.
        """
        return self._task

    def generate(self, instances: pd.DataFrame, neg_value=0, column_name="target", **kwargs) -> pd.DataFrame:
        """
        Generates counterfactuals for a given DataFrame of instances.

        Each row is passed to ``generate_for_instance`` and the per-row results
        are concatenated in row order.

        @param instances: A DataFrame of instances for which you want to generate
                          counterfactual explanations.
        @param neg_value: The value considered negative in the target variable.
        @param column_name: The name of the target column.
        @param kwargs: Extra keyword arguments forwarded unchanged to the
                       subclass' ``_generation_method``.
        @return: A DataFrame of the counterfactual explanations for the provided
                 instances; an empty DataFrame when ``instances`` has no rows.
        """
        frames = [
            self.generate_for_instance(instance, neg_value=neg_value, column_name=column_name, **kwargs)
            for _, instance in instances.iterrows()
        ]

        # pd.concat raises ValueError on an empty list, so an input without
        # rows is answered with an empty result instead of crashing.
        if not frames:
            return pd.DataFrame()

        return pd.concat(frames)

    def generate_for_instance(self, instance, neg_value=0, column_name="target", **kwargs) -> pd.DataFrame:
        """
        Generates a counterfactual for a provided instance.

        @param instance: The instance for which you would like to generate a counterfactual.
        @param neg_value: The value considered negative in the target variable.
        @param column_name: The name of the target column.
        @param kwargs: Extra keyword arguments forwarded unchanged to the
                       subclass' ``_generation_method``.
        @return: A DataFrame containing the counterfactual explanations for the instance.
        """
        return self._generation_method(instance, neg_value=neg_value, column_name=column_name, **kwargs)

    def generate_for_all(self, neg_value=0, column_name="target", **kwargs) -> pd.DataFrame:
        """
        Generates counterfactuals for all instances with a given negative value in their target column.

        The negative instances are obtained from ``self.task.get_negative_instances``.

        @param neg_value: The value in the target column which counts as a negative instance.
        @param column_name: The name of the target variable.
        @param kwargs: Extra keyword arguments forwarded unchanged to ``generate``.
        @return: A DataFrame of the counterfactuals for all negative instances,
                 indexed like the negative instances themselves.
        """
        negatives = self.task.get_negative_instances(neg_value, column_name=column_name)

        counterfactuals = self.generate(
            negatives,
            column_name=column_name,
            neg_value=neg_value,
            **kwargs
        )

        # Re-label the results with the index of the instances they explain.
        # NOTE(review): this requires exactly one result row per negative
        # instance; pandas raises ValueError on a length mismatch.
        counterfactuals.index = negatives.index

        return counterfactuals

    @abstractmethod
    def _generation_method(self, instance, column_name="target", neg_value=0, **kwargs):
        """
        Abstract method to be implemented by subclasses for generating counterfactuals.

        It is always invoked with ``neg_value`` and ``column_name`` passed as
        keyword arguments.

        @param instance: The instance for which to generate a counterfactual.
        @param column_name: The name of the target column.
        @param neg_value: The value considered negative in the target variable.
        @param kwargs: Generator-specific options supplied by the caller.
        @return: A DataFrame containing the generated counterfactuals.
        """
        pass

    @property
    def custom_distance_func(self):
        """
        Returns custom distance function passed at instantiation

        @return: distance Function, (DataFrame, DataFrame) -> Int, or None if
                 no custom function was supplied.
        """
        return self.__customFunc