MinHash LSH Implementation Walkthrough: Deduplication
MinHash Locality Sensitive Hashing (LSH) is a technique used for approximate nearest neighbor search in high-dimensional spaces.
MinHash LSH is commonly used in tasks such as near-duplicate detection, recommendation systems, and clustering over huge amounts of data. Exact nearest neighbor algorithms can provide higher accuracy, but they are computationally heavy and time-consuming.
As an additional note, MinHash LSH is not a semantic search technique; it matches records based on overlapping sequences of characters or words rather than on meaning.
There are already plenty of well-written articles that provide an in-depth explanation of the overall process, so I will keep the background brief.
Here's a quick summary of the algorithm along with high-level implementation steps:
MinHashing
- Generate a set of hash functions called MinHash functions.
- Represent each item in the dataset as a set of characteristic features (e.g., shingles, tokens).
- For each feature set, compute its MinHash signature by hashing each feature through the MinHash functions and retaining the minimum hash value for each function.
- The resulting MinHash signatures represent the items in a lower-dimensional space while preserving similarity between items.
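To make the MinHashing step concrete, here is a minimal sketch of signature-based similarity estimation using the datasketch library installed later in this walkthrough; the two example strings and the 3-character shingling are my own illustration:
from datasketch import MinHash

def shingles(text, n=3):
    # Overlapping n-character shingles: the characteristic features mentioned above
    return {text[i:i+n] for i in range(len(text) - n + 1)}

s1 = shingles("nichol prideaux 9 hicks street")
s2 = shingles("nicholas pride 9 hicks street")

# One MinHash signature per feature set; 128 permutations each
m1, m2 = MinHash(num_perm=128), MinHash(num_perm=128)
for sh in s1:
    m1.update(sh.encode('utf8'))
for sh in s2:
    m2.update(sh.encode('utf8'))

print("MinHash estimate:", m1.jaccard(m2))               # Estimated Jaccard similarity
print("Exact Jaccard:   ", len(s1 & s2) / len(s1 | s2))  # Ground truth, for comparison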
Locality Sensitive Hashing (LSH)
- Divide the MinHash signatures into bands and hash each band separately.
- Items with similar signatures will hash to the same bucket with high probability.
- By comparing only the items within the same bucket, approximate nearest neighbors can be efficiently identified.
High-Level Implementation Steps
MinHashing
- Define the number of hash functions (permutations) to use for MinHashing.
- Preprocess the dataset to extract characteristic features from each item.
- Compute the MinHash signature for each item by hashing its features through the MinHash functions and retaining the minimum hash value for each function.
Locality Sensitive Hashing (LSH)
- Divide the MinHash signatures into bands, each consisting of multiple rows.
- Hash each band separately, and group items with the same hash value into candidate buckets.
- Optionally, apply multiple hash tables for better accuracy.
For Similarity Search
- Given a query item, compute its MinHash signature.
- Hash the signature into the corresponding buckets.
- Retrieve candidate items from the buckets and perform an exact similarity comparison.
Post-Processing
- If necessary, perform exact similarity computation on the candidate items to eliminate false positives and obtain the final result.
Optimizations
- Adjust the number of bands and rows per band to balance between recall and efficiency.
- Experiment with different hash functions and parameters to optimize performance for specific datasets and similarity thresholds.
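To reason about this trade-off quantitatively: if two items have Jaccard similarity s and the signature is split into b bands of r rows each, the probability that they collide in at least one band is 1 - (1 - s^r)^b. A back-of-the-envelope sketch (the b=32, r=4 split of 128 permutations is my own illustrative choice):
def candidate_probability(s, b, r):
    # Probability that two items with Jaccard similarity s
    # share a bucket in at least one of b bands of r rows each
    return 1 - (1 - s**r)**b

# The curve is a steep "S": dissimilar pairs are mostly filtered out,
# while highly similar pairs are almost always retained as candidates
for s in (0.3, 0.5, 0.8, 0.9):
    print(s, round(candidate_probability(s, b=32, r=4), 3))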
Package Installations
The datasketch library is a powerful tool for approximate computation and analysis of large datasets, with a focus on efficiency, scalability, and accuracy. It is widely used in domains such as data mining, machine learning, and information retrieval.
!pip install recordlinkage
!pip install datasketch
Import Required Libraries
import numpy as np
import pandas as pd
import re
import time
from datasketch import MinHash, MinHashLSHForest
import recordlinkage
from recordlinkage.datasets import load_febrl1
This code loads a built-in dataset (febrl1) using the load_febrl1() function from the recordlinkage.datasets module. It then merges all columns of the dataset into a single column, except for the columns specified in the exclude_columns list. Finally, the merged column is added to the original DataFrame as a new column named "text."
# Loading and using the in-built dataset
df = load_febrl1()
# Columns to exclude from merging
exclude_columns = []
# Merge all columns into a single column except for the excluded columns
merged_column = df.drop(columns=exclude_columns).apply(lambda x: ' '.join(x.astype(str)), axis=1)
# Add the merged column to the DataFrame
df["text"]=merged_column
df.head()
df.info()
This code shuffles the DataFrame df using the sample() function with a fixed random state for reproducibility (random_state=42). Then, it calculates the number of rows for the training and testing sets based on the specified proportions (99% for training and 1% for testing).
Next, it splits the shuffled DataFrame into training and testing sets using iloc to select rows up to the calculated index for the training set (df_shuffled.iloc[:train_size]) and the remaining rows for the testing set (df_shuffled.iloc[train_size:]).
Finally, it resets the index of both DataFrames using reset_index() with drop=True to discard the old index and replace it with a new sequential one. This step is optional.
# Shuffle the DataFrame
df_shuffled = df.sample(frac=1, random_state=42) # Shuffle with a fixed random state for reproducibility
# Calculate the number of rows for training and testing
train_size = int(0.99 * len(df))
test_size = len(df) - train_size
# Split the shuffled DataFrame into training and testing sets
df_train = df_shuffled.iloc[:train_size]
df_test = df_shuffled.iloc[train_size:]
# Reset index for both DataFrames if needed (Optional)
df_train.reset_index(drop=True, inplace=True)
df_test.reset_index(drop=True, inplace=True)
The tokenize_per_n_characters function below takes a text input and tokenizes it into overlapping n-character chunks.
def tokenize_per_n_characters(text, n):
    tokens = []
    for i in range(len(text) - (n - 1)):
        tokens.append("'" + text[i:i+n] + "'")
    return tokens
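For example, with n = 3 every consecutive 3-character window becomes a token (the quote characters are added by the function itself):
print(tokenize_per_n_characters("hello", 3))
# ["'hel'", "'ell'", "'llo'"]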
This preprocess_text function is designed to preprocess a text input.
- Lowercasing: It converts the entire text to lowercase using the lower() method.
- Removing punctuation, '@' symbols, and dots: It defines a string variable remove_chars containing all punctuation symbols, '@' symbols, and dots, then uses the translate() method with str.maketrans() to remove these characters from the text.
- Tokenization: It calls the tokenize_per_n_characters() function defined above to tokenize the preprocessed text into tokens of a specified length n.
import string

def preprocess_text(text, n=3):
    # n is the shingle length; 3 is an assumed default
    # Lowercasing
    text = str(text).lower()
    # Removing punctuation, @ symbols, and dots
    remove_chars = string.punctuation + '@.'
    text = text.translate(str.maketrans('', '', remove_chars))
    # Character n-gram tokenization
    tokens = tokenize_per_n_characters(text, n)
    return tokens
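A quick check of what this preprocessing produces (the example input is my own):
print(preprocess_text("Jane@Doe."))
# ["'jan'", "'ane'", "'ned'", "'edo'", "'doe'"]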
This get_forest function is designed to build an LSH Forest for a given dataset.
Input Parameters
- data: The dataset containing the text to be hashed. In this case, it will be df_train.
- perms: The number of permutations for the MinHash. The key idea behind MinHash is to use permutations of the elements in the set to create different hash functions. Each permutation defines a unique hash function. The number of permutations used in MinHash determines the accuracy of similarity estimation. More permutations generally lead to better accuracy but require more computational resources.
- bands: The number of bands to divide the MinHash signature into.
- rows_per_band: The number of rows per band for LSH.
Process
- Tokenization: It iterates through each text in the dataset and preprocesses it using the preprocess_text function.
- Hashing: Then, it generates MinHash signatures for each text using the specified number of permutations (perms). These MinHash signatures are added to a list.
- Building the LSH forest: It initializes an LSH Forest with the specified number of permutations (perms). Then, it adds each MinHash signature to the LSH Forest.
- Indexing: After adding all MinHash signatures, it indexes the LSH Forest.
- Querying for candidate pairs: It queries the LSH Forest to find candidate pairs of similar texts. For each MinHash signature, it retrieves a specified number of candidate results (num_results) from the LSH Forest.
- Grouping candidate pairs into bands: It groups candidate pairs into bands based on their hash values. This is done by hashing the first band values of the MinHash signature.
- Printing bands with multiple pairs: It prints the bands along with the pairs of similar texts found within each band.
Output
- It returns the constructed LSH Forest (forest) and a dictionary (bands_dict) containing bands as keys and pairs of similar texts as values.
- Additional notes: Indexing in an LSH (Locality Sensitive Hashing) Forest involves organizing the data into hash tables based on their hash values. Each hash table contains buckets, and each bucket stores references to data points with similar hash values. This indexing process allows for efficient retrieval of approximate nearest neighbors during query operations. The key idea is to group similar data points together based on their hash signatures, enabling faster search and retrieval compared to brute-force methods, especially in high-dimensional spaces.
def get_forest(data, perms, bands, rows_per_band):
    start_time = time.time()
    minhash = []
    for text in data['text']:
        tokens = preprocess_text(text)
        m = MinHash(num_perm=perms)
        #print("value of m is ", str(m))
        for s in tokens:
            m.update(s.encode('utf8'))
        minhash.append(m)
    #for i, mh in enumerate(minhash):
    #    print("MinHash {}: hash values = {}".format(i+1, mh.hashvalues))
    forest = MinHashLSHForest(num_perm=perms)
    for i, m in enumerate(minhash):
        forest.add(i, m)
    forest.index()
    # Query LSH Forest to find candidate pairs
    candidate_pairs = {}
    num_results = rows_per_band
    for i, m in enumerate(minhash):
        result = forest.query(m, num_results)
        for j in result:
            if i < j:  # Ensure no duplicate pairs
                candidate_pairs[(i, j)] = True
    # Group candidate pairs into bands based on their hash values
    bands_dict = {}
    for (row1, row2) in candidate_pairs.keys():
        band_hash = hash(tuple(sorted(minhash[row1].hashvalues[:bands])))
        if band_hash not in bands_dict:
            bands_dict[band_hash] = []
        bands_dict[band_hash].append((row1, row2))
    # Print bands with multiple pairs
    for band, pairs in bands_dict.items():
        if len(pairs) > 1:
            print("Band hash:", band)
            print("Pairs:", pairs)
    print('It took %s seconds to build forest.' % (time.time() - start_time))
    return forest, bands_dict
Sample Output:
Band hash: -7380833634571281130
Pairs: [(9, 978), (402, 823)]
Band hash: -1902129179255886798
Pairs: [(16, 727), (255, 788), (733, 879)]
Band hash: The hash value calculated from the MinHash signature of the records in a specific band. Each band consists of multiple MinHash values, and the band hash is computed by hashing these values. It helps in grouping similar records into bands efficiently.
Pairs: A list of tuples, where each tuple represents a pair of indices referring to records in the dataset. These indices indicate the positions of the similar records in the original dataset. For example, (9, 978) indicates that the records at positions 9 and 978 in the dataset are similar, and (402, 823) indicates that the records at positions 402 and 823 are similar.
# Number of permutations
permutations = 128
bands = 2
rows_per_band = 2
f, b = get_forest(df_train, permutations, bands, rows_per_band)
The next function iterates over the pairs of similar records stored in the bands_dict dictionary, retrieves the corresponding records from the data DataFrame using their indices, and appends them to a list of similar records. Finally, it returns the list of similar records found in the existing training dataset.
def check_similar_records(bands_dict, data):
    similar_records = []
    for pairs in bands_dict.values():
        for (row1, row2) in pairs:
            record1 = data.iloc[row1]
            record2 = data.iloc[row2]
            similar_records.append((record1, record2))
    return similar_records
# Usage
similar_records = check_similar_records(b, df_train)
print(len(similar_records))  # It has identified 491 similar records
print(similar_records[1])
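As noted in the post-processing step earlier, these candidates can still contain false positives, so it is worth re-checking them with an exact similarity measure. Below is a minimal sketch of such a verification pass using the exact Jaccard similarity of the same character shingles; the 0.7 threshold is an arbitrary choice of mine:
def exact_jaccard(text1, text2):
    # Exact Jaccard similarity over the same shingles used for MinHashing
    set1, set2 = set(preprocess_text(text1)), set(preprocess_text(text2))
    if not set1 or not set2:
        return 0.0
    return len(set1 & set2) / len(set1 | set2)

# Keep only candidate pairs whose exact similarity clears the threshold
verified = [(r1, r2) for (r1, r2) in similar_records
            if exact_jaccard(r1["text"], r2["text"]) >= 0.7]
print(len(verified), "pairs survive exact verification")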
The predict function allows you to find similar records in a database based on a given text input efficiently using MinHash and LSH forest indexing.
- Preprocessing: It preprocesses the input text by tokenizing it and generating a MinHash object based on the tokens.
- Querying the LSH forest: It queries the LSH forest with the generated MinHash object to find similar records in the database.
- Retrieving results: It retrieves similar records from the database based on the indices returned by the LSH forest query.
- Output: It returns similar records found in the database based on the input text.
def predict(text, database, perms, num_results, forest):
    start_time = time.time()
    tokens = preprocess_text(text)
    m = MinHash(num_perm=perms)
    for s in tokens:
        m.update(s.encode('utf8'))
    idx_array = np.array(forest.query(m, num_results))
    if len(idx_array) == 0:
        return None  # If the query returns no matches, return None
    result = database.iloc[idx_array]['text']
    #print('It took %s seconds to query forest.' % (time.time() - start_time))
    return result
rec: The actual text for which we are finding matching records in the forest (f) built from the df_train dataset. The permutation value must remain the same as the one used to build the forest.
# Number of recommendations to return
num_recommendations = 1
df_test_small = df_test.head(5)
df_test_small["text"].head()
#rec = "nichol prideaux 9 hicks street loccation 7229 bondi junction 5011 sa 19420429 1107619"
for rec in df_test_small["text"]:
    print("Actual records is ", rec, "\n")
    res = predict(rec, df_train, permutations, num_recommendations, f)
    print("Similar records is ", res)
    print("\n\n")
Actual records is  marianne rees 6 nerli place weeroona mansfield 3644 sa 19770628 4235561
Similar records is  789    rees maribanne 6 nerli place weeroona mansfield 3644 sa 19770628 4235561
Name: text, dtype: object

Actual records is  isabella ayres 1075 arthur circle grevilla est glenwood 5112 qld 19531230 9208787
Similar records is  390    isabella ayrse 1075 arthur circle grevilal est glewnood 5121 qld 19531230 9208787
Name: text, dtype: object

Actual records is  kirra brock 93 clarnette place backwoodlands orbost 2528 vic 19231209 5107876
Similar records is  407    kirar brock 93 clarnette place nan orbost 2528 vic 19231209 5107876
Name: text, dtype: object

Actual records is  christina coleman 506 hall street rsde 817 wallace woolgoolga 3073 qld 19070811 6822993
Similar records is  457    christina colemn 506 hall street rsde 817 wallace woolgoolga 3063 qld 19070811 6822993
Name: text, dtype: object

Actual records is  joshua rickett 1 burraly court malladup new farm 2439 act 19030121 4310453
Similar records is  486    joshua rickett 1 burraly corut malladuo belleve hill 2439 act 19030121 4310453
Name: text, dtype: object
The function below allows you to incrementally update the LSH forest with new text records and identify any new bands created by these updates.
- Preprocessing: It preprocesses the new text by tokenizing it.
- Creating MinHash: It generates a MinHash object for the new text based on the tokens.
- Adding MinHash to the forest: It adds the new MinHash to the LSH forest.
- Querying the forest: It queries the LSH forest to find candidate pairs for the new MinHash.
- Grouping into bands: It groups the candidate pairs into bands based on their hash values.
- Printing new bands: If the new record creates a band with multiple pairs, it prints out the band hash and the pairs.
def update_forest(forest, bands_dict, new_text, perms, bands, rows_per_band):
    # Preprocess the new text
    tokens = preprocess_text(new_text)
    # Create MinHash for the new text
    new_minhash = MinHash(num_perm=perms)
    for s in tokens:
        new_minhash.update(s.encode('utf8'))
    new_index = "m6"
    # Add the new MinHash to the forest
    forest.add(new_index, new_minhash)
    forest.index()
    # Query LSH Forest to find candidate pairs with the new MinHash
    num_results = rows_per_band
    result = forest.query(new_minhash, num_results)
    # Group candidate pairs into bands based on their hash values
    new_band_hash = hash(tuple(sorted(new_minhash.hashvalues[:bands])))
    if new_band_hash not in bands_dict:
        bands_dict[new_band_hash] = []
    bands_dict[new_band_hash].append(new_index)  # Append the index of the new record
    # Print the new bands with multiple pairs
    if len(result) > 1:
        print("New record created a band with multiple pairs:")
        print("Band hash:", new_band_hash)
        # Exclude the new record itself; comparing mixed int/str keys with < would raise a TypeError
        print("Pairs:", [(new_index, idx) for idx in result if idx != new_index])
There is a caveat with this function regarding the index (new_index): every new insert needs a different value. You can automate this by generating a unique hash or UUID, but for now I have left it as a value to change manually for each new data insert.
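One simple way to automate this is to generate the index with Python's built-in uuid module (my own suggestion, not part of the original walkthrough):
import uuid

new_index = str(uuid.uuid4())  # Unique per insert, e.g. 'f47ac10b-58cc-4372-a567-0e02b2c3d479'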
new_text="nicholas pride 9 hicks street loccn 7229 bond junction 5011 sa 19420429 1107619"
update_forest(forest, b, new_text, permutations, bands, rows_per_band)
New record created a band with multiple pairs:
Band hash: -8713188616659143193
Pairs: [('m6', 'm5'), ('m6', 'm2')]