diff --git a/breast_cancer_queries.py b/breast_cancer_queries.py index 79372f8638751c546916fd7d7be5d63451710bb8..5c67bae6e3346a119f494d6b7582e638c394878a 100644 --- a/breast_cancer_queries.py +++ b/breast_cancer_queries.py @@ -4,10 +4,12 @@ import csv import sys from hle import high_level_single # our code +# Load dataset from CSV file with open('data/breast-cancer.csv', 'r') as f: reader = csv.reader(f, delimiter=';') full_dataset = list(reader) +# Define feature names features = { 'clumpThickness': 'numeric', 'uniformityCellSize': 'numeric', @@ -20,16 +22,18 @@ features = { 'mitoses': 'numeric', } +# Define Class Names class_names = ['benign', 'melignant'] feature_names = list(features.keys()) feature_types = list(features.values()) -# because of binary features with values that are not 0 or 1. +# because of binary features with values that are not 0 or 1. (not needed here, leaving it just in case we need it) feature_mapping = { } +# Process row of features from dataset def process_features_student(row): to_delete = [0] cpy = [] @@ -43,21 +47,26 @@ def process_features_student(row): assert len(cpy) == len(feature_names) return cpy + +# Process Class Label def process_class(val): if float(val) >= 3: # good grade is a grade in [10, 20]. Bad grade is [0, 10) return 0 else: return 1 +# Prepare dataset by splitting features and labels dataset = full_dataset[1:] X = [ process_features_student(data[:-1]) for data in dataset] y = [ process_class(data[-1]) for data in dataset] +# Init and Train decision tree classifier cancer_clf = DecisionTreeClassifier(max_leaf_nodes=400, random_state=0) cancer_clf.fit(X, y) print('DecisionTreeClassifier has been trained') +# Example Queries (feel free to add more) q1 = 'exists p1, exists p2, benign(p1) implies benign(p2)' q2 = 'exists p1, exists p2, p1.blandChromatin > 3 and p2.marginalAdhesion <= 3 and melignant(p1) implies benign(p2)' q3 = 'for every patient, patient.blandChromatin > 4 implies melignant(patient)' @@ -70,7 +79,7 @@ q6 = ('exists p1, exists p2, p1.mitoses <= 2 implies melignant(p1)' 'and p2.blandChromatin > 9 implies p1.blandChromatin <= 3') - +# Eval Example Queries def example_queries(): queries = [q1,q2,q3,q4,q5,q6] avg = 0