Filtering pandas DataFrames
In this lesson, you will learn how to filter rows of a pandas DataFrame based on specific conditions.
We already know how to extract rows from a pandas DataFrame using DataFrame.loc[].
For instance, here we slice rows with indexes 5 to 10 (inclusive):
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# slice rows with index 5-10
print(df.loc[5:10])
Instead of using indexes, DataFrame.loc[] also accepts a boolean array to extract specific rows.
One way to get such an array is by applying a condition to a column of our DataFrame.
For example, we could check if each company in our DataFrame was established after 2020:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# apply condition to 'Year Established' column
print(df['Year Established'] > 2020)
The result is an array of boolean values.
We can use this array to filter out rows where this condition does not apply.
To do this, we simply pass the boolean array to DataFrame.loc[]:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# extract rows where condition is true
print(df.loc[df['Year Established'] > 2020])
Only 6 companies in our DataFrame were founded after 2020.
If we want to continue working with our filtered DataFrame, we can assign it to a new variable:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# create new filtered DataFrame
subset = df.loc[df['Year Established'] > 2020]
print(subset)
One thing to note about this newly filtered DataFrame is that the rows maintain their original index values.
This can be helpful in some cases but may also cause issues.
For example, if we want to extract the third row from our subset, we would expect the following syntax to work:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
subset = df.loc[df['Year Established'] > 2020]
# extract 3rd row from subset
print(subset.loc[2])
However, we actually get a KeyError, because our subset contains no row with the index label 2.
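As an aside, if we just want a row at a given position regardless of its label, DataFrame.iloc[] selects rows by position instead of by index label (a minimal sketch using the same filtered subset):
import pandas as pd
df = pd.read_csv('/data/companies.csv')
subset = df.loc[df['Year Established'] > 2020]
# select the 3rd row by position; iloc ignores the original index labels
print(subset.iloc[2])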
To solve this problem, we can use the DataFrame.reset_index() method:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# create new filtered DataFrame
subset = df.loc[df['Year Established'] > 2020]
# reset index values
subset = subset.reset_index()
print(subset)
Now, the old index is added as a column, and the rows have a new sequential index starting from 0.
To avoid the additional column with the old indices, use DataFrame.reset_index(drop=True).
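Here is the same filter as a minimal sketch with the old index dropped right away:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# create new filtered DataFrame and discard the old index
subset = df.loc[df['Year Established'] > 2020].reset_index(drop=True)
print(subset)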
Before we conclude this tutorial, I want to introduce another method for filtering pandas DataFrames: DataFrame.query().
Using DataFrame.query(), the filtering we did earlier looks like this:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# filter using query
print(df.query('`Year Established` > 2020'))
The DataFrame.query() method accepts a string representing a query expression.
The expression can contain column names together with Python operators like <, >, and ==, as well as the boolean keywords and and or.
Instead of using df['column_name'], we can refer to column names directly, which makes the code more readable.
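For example, a quick sketch of a query that combines two conditions with or (the 1990 cutoff is just an arbitrary value for illustration):
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# filter using query with an 'or' condition (1990 is an arbitrary illustrative cutoff)
print(df.query('`Year Established` > 2020 or `Year Established` < 1990'))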
Note that we used backticks around our column name in the query expression: `Year Established`.
This is because the label Year Established contains a whitespace character and is therefore not a valid standalone Python variable name.
If the label is a valid Python variable name, you can remove the backticks.
Here we extract all rows from our companies DataFrame where the value of CEO is Melissa Robinson:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# filter using query
print(df.query('CEO == "Melissa Robinson"'))
The DataFrame.query() method is very flexible and makes it easy to combine multiple conditions.
Let's extract all companies with negative profit in the Technology sector:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# filter using query
print(df.query('`Revenue (USD)` < `Expenses (USD)` and Industry == "Technology"'))
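For comparison, here is a rough sketch of the same filter written with boolean indexing and DataFrame.loc[]; note that pandas requires the & operator with parentheses around each condition instead of the plain and keyword:
import pandas as pd
df = pd.read_csv('/data/companies.csv')
# same filter with boolean indexing: & instead of 'and', parentheses around each condition
print(df.loc[(df['Revenue (USD)'] < df['Expenses (USD)']) & (df['Industry'] == 'Technology')])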