Feature importance when using embeddings as input
Checking feature importance when using embeddings as input features: since individual embedding dimensions are not human-interpretable, we use LIME on the raw text to map importance back to the input words.
1. Multi-class text classification
1.1. Imports
```python
import numpy as np
import pandas as pd
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from lime.lime_text import LimeTextExplainer

# Load E5 embedding model
model = SentenceTransformer("intfloat/e5-base")  # You can use another E5 model
```
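One caveat when using E5: the intfloat/e5 model cards recommend prefixing every input with "query: " or "passage: " before encoding, and embeddings produced without a prefix can be slightly degraded. A minimal sketch of such a wrapper (the helper name `encode_e5` is ours, not part of the original code):

```python
# Sketch: the E5 model card recommends an instruction prefix on every input
def encode_e5(sentences, prefix="query: "):
    return np.array(model.encode([prefix + s for s in sentences]))
```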
1.2. Data
```python
texts = [
    # Economy & Politics (0)
    "The stock market surged after a strong jobs report.",
    "New economic policies aim to reduce inflation.",
    "The central bank raised interest rates again this month.",
    "The trade war between the two countries is escalating.",
    "A new tax reform bill has been proposed in Congress.",
    "The GDP growth rate exceeded analysts' expectations.",
    "Government spending on infrastructure projects increased.",
    "The housing market is experiencing a downturn.",
    "The president announced a new foreign trade deal.",
    "Global oil prices are affecting national economies.",
    "Stock investors are cautious due to market volatility.",
    "The unemployment rate dropped to a five-year low.",
    "A new bill proposes tax incentives for small businesses.",
    "The finance minister announced new banking regulations.",
    "Consumer confidence in the economy is improving.",
    # Technology & AI (1)
    "A breakthrough in quantum computing was announced today.",
    "Artificial intelligence is transforming customer service.",
    "New cybersecurity threats are emerging with deepfake technology.",
    "The latest smartphone features an advanced neural chip.",
    "A major software update improved the performance of self-driving cars.",
    "Python remains the top choice for AI and data science.",
    "Researchers developed an AI system that can generate human-like text.",
    "Cloud computing services are expanding globally.",
    "Tech companies are investing heavily in augmented reality.",
    "A startup built an AI-powered legal assistant.",
    "Data privacy concerns are growing with new surveillance tech.",
    "Blockchain technology is being used beyond cryptocurrencies.",
    "The demand for skilled machine learning engineers is rising.",
    "Big Tech firms are competing in the generative AI race.",
    "Open-source software is driving innovation in AI research.",
    # Sports (2)
    "The World Cup final was watched by millions worldwide.",
    "The NBA playoffs are delivering intense matchups.",
    "A new world record was set in the 100m sprint.",
    "The top tennis player won another Grand Slam title.",
    "The Super Bowl halftime show featured an unforgettable performance.",
    "An underdog team won the national soccer championship.",
    "The Formula 1 season is seeing unexpected podium finishes.",
    "The new football season has started with exciting games.",
    "Olympic athletes are preparing for the upcoming Games.",
    "The boxing match between the two champions ended in a knockout.",
    "A rising star in golf won his first major tournament.",
    "The latest UFC fight card featured thrilling bouts.",
    "The baseball league introduced new rules to speed up the game.",
    "The Tour de France saw a dramatic finish in the final stage.",
    "A major controversy emerged over doping in track and field.",
]

labels = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  # Economy/Politics
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  # Tech/AI
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,  # Sports
]

# Mapping numeric labels to string class names
class_names = {
    0: "Economy/Politics",
    1: "Tech/AI",
    2: "Sports",
}
```
1.3. Training a model
```python
# Encode texts into E5 embeddings
X = np.array(model.encode(texts))  # Convert sentences to vector embeddings
y = np.array(labels)               # Convert labels to numpy array

# Train-test split (pass texts too, so raw sentences stay aligned with the split)
X_train, X_test, y_train, y_test, text_train, text_test = train_test_split(
    X, y, texts, test_size=0.3, random_state=42
)

# Train logistic regression model
clf = LogisticRegression(max_iter=1000, solver="lbfgs")
clf.fit(X_train, y_train)
```
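Before turning to explanations, it is worth a quick sanity check that the classifier actually learned the task. A short evaluation sketch (not part of the original notebook) using scikit-learn's standard metrics:

```python
from sklearn.metrics import accuracy_score, classification_report

# Evaluate on the held-out test split
y_pred = clf.predict(X_test)
print(f"Test accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(classification_report(y_test, y_pred, target_names=list(class_names.values())))
```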
1.4. Explaining one example
```python
# Define a prediction function for LIME
def predict_proba(text_list):
    embeddings = np.array(model.encode(text_list))  # Convert text to E5 embeddings
    return clf.predict_proba(embeddings)  # Return probability distribution over classes

# Initialize LIME explainer (it expects a list of class names, not a dict)
explainer = LimeTextExplainer(class_names=list(class_names.values()))

# Explain a test example (change idx to analyze different test samples)
idx = 0  # Index of test example
text_sample = text_test[idx]
print(f"Text Sample: {text_sample}")
print(f"True Label: {y_test[idx]}")
print(f"Predictions: {predict_proba([text_sample])}")

# Generate LIME explanation
exp = explainer.explain_instance(text_sample, predict_proba, num_features=10)

# Show important words for the prediction
# exp.show_in_notebook()  # Uncomment if running in a Jupyter notebook
print(exp.as_list())  # Print word importance scores
```
```
Text Sample: The boxing match between the two champions ended in a knockout.
True Label: 2
Predictions: [[0.27847295 0.22740878 0.49411826]]
[(np.str_('ended'), -0.019836899479228738), (np.str_('champions'), -0.016320070944142794), (np.str_('match'), -0.012017644251766983), (np.str_('boxing'), -0.011215725851938212), (np.str_('The'), -0.005004482383459089), (np.str_('two'), -0.0039665240041576195), (np.str_('between'), -0.0035507443198058678), (np.str_('knockout'), -0.0021179273721672527), (np.str_('in'), -0.0020914076748228707), (np.str_('the'), -0.0008708649982305064)]
```
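Note that every score above is negative even though the sample is correctly predicted as Sports. This is a LIME default worth knowing: `explain_instance` uses `labels=(1,)`, so `exp.as_list()` reports contributions toward class index 1 (Tech/AI here), and the boxing-related words all push the prediction away from that class. A small sketch of explaining the predicted class instead:

```python
# Explain the predicted class rather than LIME's default label index 1
pred = int(np.argmax(predict_proba([text_sample])[0]))
exp = explainer.explain_instance(
    text_sample, predict_proba, num_features=10, labels=[pred]
)
print(exp.as_list(label=pred))  # Word importances toward the predicted class
```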
1.5. Aggregating by class
```python
# Initialize LIME explainer
explainer = LimeTextExplainer(class_names=list(class_names.values()))

# Dictionary to accumulate importance scores for each class
class_importance = defaultdict(lambda: defaultdict(list))

# Iterate over test samples
for idx, text_sample in enumerate(text_test):
    exp = explainer.explain_instance(text_sample, predict_proba, num_features=10)
    pred_class = clf.predict([X_test[idx]])[0]  # Get the predicted class index
    class_name = class_names[pred_class]        # Convert index to class name

    # Store importance scores for each word
    for word, score in exp.as_list():
        class_importance[class_name][word].append(score)

# Aggregate: compute mean importance for each class
class_aggregate = {
    cls: {word: np.mean(scores) for word, scores in word_dict.items()}
    for cls, word_dict in class_importance.items()
}

# Convert to DataFrame for easier visualization
df = pd.DataFrame(class_aggregate).fillna(0)

# Display top 10 important words per class
for c in df.columns:
    print(f"\nTop words for {c}:")
    print(df[c].sort_values(ascending=False).head(10))
```
```
Top words for Sports:
latest    0.022434
new       0.010803
UFC       0.005392
A         0.004181
de        0.001716
card      0.001536
world     0.000266
cars      0.000000
major     0.000000
of        0.000000
Name: Sports, dtype: float64

Top words for Tech/AI:
AI            0.091387
software      0.046983
technology    0.045928
tech          0.045213
self          0.037621
driving       0.027930
startup       0.023201
Data          0.021444
assistant     0.017712
built         0.016041
Name: Tech/AI, dtype: float64

Top words for Economy/Politics:
Congress          0.013226
projects          0.010172
is                0.009915
proposed          0.009739
proposes          0.007490
businesses        0.006913
new               0.006638
foreign           0.006551
infrastructure    0.004503
has               0.002549
Name: Economy/Politics, dtype: float64
```
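The same default-label caveat from section 1.4 applies here: every `exp.as_list()` call above scores words against class index 1 (Tech/AI), regardless of which class the sample was binned under. That is why the Tech/AI column dominates while the other two mostly surface near-zero stopwords. A sketch of the aggregation with each sample attributed to its own predicted class instead:

```python
# Re-run the aggregation, scoring words against each sample's predicted class
class_importance = defaultdict(lambda: defaultdict(list))
for idx, text_sample in enumerate(text_test):
    pred_class = clf.predict([X_test[idx]])[0]
    exp = explainer.explain_instance(
        text_sample, predict_proba, num_features=10, labels=[pred_class]
    )
    for word, score in exp.as_list(label=pred_class):
        class_importance[class_names[pred_class]][word].append(score)
```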
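Finally, the aggregated table is easier to eyeball as a chart. A minimal matplotlib sketch (matplotlib is an extra dependency here, not imported above):

```python
import matplotlib.pyplot as plt

# One horizontal bar chart of the top-10 words per class
fig, axes = plt.subplots(1, len(df.columns), figsize=(15, 4))
for ax, c in zip(axes, df.columns):
    top = df[c].sort_values(ascending=False).head(10)
    ax.barh(top.index[::-1], top.values[::-1])  # Largest score at the top
    ax.set_title(c)
    ax.set_xlabel("Mean LIME importance")
plt.tight_layout()
plt.show()
```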