Published unlisted
Edited
Jul 25, 2020
Insert cell
Insert cell
Insert cell
Insert cell
d3 = require("d3", "d3-array@^2.1");
Insert cell
z = require("zebras");
Insert cell
_ = require("lodash");
Insert cell
Plotly = require("https://cdn.plot.ly/plotly-latest.min.js");
Insert cell
ml = require("https://www.lactame.com/lib/ml/4.0.0/ml.min.js");
Insert cell
Insert cell
url = "https://raw.githubusercontent.com/BeirutAI/medical-cost-prediction/master/data/insurance.csv"
Insert cell
Insert cell
// read CSV file in d3.
data = d3.csv(url).then((data) => {
// parsing numbers.
return z.parseNums(["age", "bmi", "children", "charges"], data);
});
Insert cell
// display first 5 rows as an ASCII table.
z.printHead(5, data);
Insert cell
Insert cell
{
// get the number of rows and columns
let rows = z.shape(data)[0];
let columns = z.shape(data)[1];
return `There are ${rows} rows and ${columns} columns.`;
}
Insert cell
Insert cell
describeData(data, true);
Insert cell
Insert cell
charges = z.getCol("charges", data);
Insert cell
// plot the histogram of the charges
{
// create a div to for the chart.
let div = html`<div id="charges_chart"></div>`;
yield div;
// plot the charges.
let chart = new plot4(charges, "charges_chart");
chart.type("hist");
chart.title("Distribution of charges");
chart.xlabel("Charges");
chart.ylabel("Frequency");
chart.show();
}
Insert cell
Insert cell
// select smokers
smokers = data.filter(data => data.smoker == "yes");
Insert cell
// select non smokers
non_smokers = data.filter(data => data.smoker == "no");
Insert cell
`There are ${smokers.length} smokers and ${non_smokers.length} non-smokers.`
Insert cell
md`Now let's plot the charges for both.`
Insert cell
{
// get charges for both.
let smokersCharges = z.getCol("charges", smokers);
let nonSmokersCharges = z.getCol("charges", non_smokers);
// create a div for containing both div charts.
let div =
html`<div>
<div id="smoker-chart" style="width:500px; display:table-cell; border: 1px solid #000;"></div>
<div id="non-smoker-chart" style="width:500px; display:table-cell; border: 1px solid #000;"></div>
</div>`;
yield div;
// plot first one.
let chart1 = new plot4(smokersCharges, "smoker-chart");
chart1.type("hist");
chart1.title("Distribution of charges for smokers");
chart1.show();
// plot second one.
let chart2 = new plot4(nonSmokersCharges, "non-smoker-chart");
chart2.type("hist");
chart2.title("Distribution of charges for non-smokers");
chart2.show();
}
Insert cell
Insert cell
Insert cell
// get age data
age = z.getCol("age", data);
Insert cell
{
// plot histogram for age distribution
let div = html`<div id="age-chart" style="width: 500px;"></div>`;
yield div;
let chart = new plot4(age, "age-chart");
chart.type("hist");
chart.title("Age distribution");
chart.xlabel("Age");
chart.ylabel("Frequency");
chart.bins(20);
chart.show();
}
Insert cell
Insert cell
{
// draw a scatter plot to show correlation between age and charges
let div = html`<div id="scatter-age-chart" style="width: 500px;"></div>`;
yield div;
let chart = new plot4([age, charges], "scatter-age-chart");
chart.type("scatter");
chart.title("Cost of treatment for different ages");
chart.xlabel("Age");
chart.ylabel("Charges");
chart.show();
}
Insert cell
Insert cell
{
let smokersCharges = z.getCol("charges", smokers);
let nonSmokersCharges = z.getCol("charges", non_smokers);
let smokersAge = z.getCol("age", smokers);
let nonSmokersAge = z.getCol("age", non_smokers);
let div = html`<div id="scatter-age-smoke-chart" style="width: 540px;"></div>`;
yield div;
let chart = new plot4([
[smokersAge, smokersCharges, "smokers", "r"],
[nonSmokersAge, nonSmokersCharges, "non_smokers", "b"]
], "scatter-age-smoke-chart");
chart.type("scatter-m");
chart.title("Cost of treatment for different ages");
chart.xlabel("Age");
chart.ylabel("Charges");
chart.show();
}
Insert cell
md`#### Correlation between BMI and cost of treatment
# Body Mass Index
![alt text](https://4.bp.blogspot.com/-nBF9Z1tFGhI/W3MqbdD0j7I/AAAAAAAAAIs/UdyXTIxsBT8Pl8usABxEK_Fusj5S0SnBQCLcBGAs/s1600/HOW%2BTO%2BCALCULATE%2BBODY%2BMASS%2BINDEX%2BBMI.jpg)
# BMI Chart
![BMI char](https://images.squarespace-cdn.com/content/v1/56fae4be1d07c0c393d8faa5/1551103826935-HCXS8U78500C06GQ1PLJ/ke17ZwdGBToddI8pDm48kNMeyc_nGAbaGjp3EBJ2o08UqsxRUqqbr1mOJYKfIPR7LoDQ9mXPOjoJoqy81S2I8N_N4V1vUb5AoIIIbLZhVYxCRW4BPu10St3TBAUQYVKckzCNDuUMr1wTvf7-fqd2hrX5O2-_PoO3UJ2jNU1VzJbe6G9-F0r9BTATNUu-NBMy/BMI+Chart.jpg)
First, let's look at the distribution of BMI in our dataset, and then look at how it affects the cost of treatment.
`
Insert cell
// get bmi.
bmi = z.getCol("bmi", data);
Insert cell
{
// draw a histogram to show the distribution of BMI
let div = html`<div id="bmi_chart" style="width: 500px;"></div>`;
yield div;
let chart = new plot4(bmi, "bmi_chart");
chart.type("hist");
chart.title("BMI distribution");
chart.xlabel("BMI");
chart.ylabel("Frequency");
chart.show();
}
Insert cell
// select obese people
obese = data.filter(d => d.bmi >= 30);
Insert cell
// select overweight people
overweight = data.filter(d => (d.bmi >= 25 && d.bmi < 30));
Insert cell
// select healthy people
healthy = data.filter(d => d.bmi < 25);
Insert cell
`There are ${obese.length} obese, ${overweight.length} overweight and ${healthy.length} healthy individuals.`
Insert cell
Insert cell
{
let div = html`<div id="chart8" style="width: 500px;"></div>`;
yield div;
let obeseCharges = z.getCol("charges", obese);
let overweightCharges = z.getCol("charges", overweight);
let healthyCharges = z.getCol("charges", healthy);
let chart = new plot4([
[obeseCharges, "obese", "r"],
[overweightCharges, "overweight", "y"],
[healthyCharges, "healthy", "g"]
], "chart8");
chart.type("hist-m");
chart.title("Charges distribution");
chart.xlabel("Charges");
chart.ylabel("Frequency");
chart.ticks(false);
chart.show();
}
Insert cell
md`Patients with BMI above 30 spend more on treatment!`
Insert cell
Insert cell
// print how many missing value in each column.
describeData(data)["count"];
Insert cell
// find rows with missing values.
missing = {
let indexs = [];
data.forEach((row, i) => {
if (isNaN(row["bmi"])) {
indexs.push(i);
};
});
let rows = [];
indexs.forEach(i => {
let row = data[i];
row.index = i;
rows.push(row);
});
return rows;
}
Insert cell
// removing rows with missing values.
missing.forEach((row, i) => data.splice(row.index-i,1));
Insert cell
describeData(data, true);
Insert cell
md`#### Remove unused columns
Let's remove the \`region\` column since we don't really care about it
`
Insert cell
// dropping the region feature
copyOfdata = z.dropCol("region", data);
Insert cell
z.printHead(2, copyOfdata);
Insert cell
Insert cell
{
// define dictionary
let gender = {"male": 0, "female": 1};
// replace sex column with 0/1
copyOfdata.forEach(row => row.sex = gender[row.sex]);
// print head to verify
return z.printHead(5, copyOfdata)
}
Insert cell
{
// define dictionary
let smokers = {"no": 0, "yes": 1};
// replace smokers column with 0/1
copyOfdata.forEach(row => row.smoker = smokers[row.smoker]);
// print head to verify
return z.printHead(5, copyOfdata)
}
Insert cell
Insert cell
// get the max of each column
data_max = describeData(copyOfdata)["max"];
Insert cell
{
// divide each column by its maximum value
copyOfdata.forEach(row => {
let columns = Object.keys(copyOfdata[0]);
columns.forEach(column => row[column] = row[column] / data_max[column]);
});
return describeData(copyOfdata, true);
}
Insert cell
Insert cell
Insert cell
model_data = {
// store smoker column as input x.
let x = z.getCol("smoker",copyOfdata);
// store the charges column as the output (label) in y
let y = z.getCol("charges",copyOfdata);
// split dataset in a 80/20 split
let test_size = 0.2;
let [x_train, x_test] = _.chunk(x, data.length*(1-test_size));
let [y_train, y_test] = _.chunk(y, data.length*(1-test_size));
let train_data = [x_train, y_train];
let test_data = [x_test, y_test];
return {train_data, test_data};
}
Insert cell
// create our regression model
model = new ml.SimpleLinearRegression(model_data["train_data"][0], model_data["train_data"][1]);
Insert cell
// test our model
model.score(model_data["test_data"][0], model_data["test_data"][1]);
Insert cell
Insert cell
function describeData(data, options) {
let toStr = options || false;
let header = Object.keys(data[0]);
let acceptable = [];
header.forEach(head => {
if (typeof data[0][head] == "number") {
acceptable.push(head);
}
});
let names = ["count", "mean", "std", "min", "25%", "50%", "75%", "max"];
let stats = {};
names.forEach(elt => stats[elt]={});
acceptable.forEach(head => {
let count = d3.count(data, d => d[head]);
let mean = d3.mean(data, d => d[head]);
let std = d3.deviation(data, d => d[head]);
let [min, max] = d3.extent(data, d => d[head]);
let x25 = d3.quantile(data, 0.25, d => d[head]);
let x50 = d3.quantile(data, 0.5, d => d[head]);
let x75 = d3.quantile(data, 0.75, d => d[head]);
stats["count"][head]=count;
stats["mean"][head]=mean;
stats["std"][head]=std;
stats["min"][head]=min;
stats["25%"][head]=x25;
stats["50%"][head]=x50;
stats["75%"][head]=x75;
stats["max"][head]=max;
});
if (typeof toStr == "boolean" && toStr) {
let df = [];
names.forEach(name => {
let obj = {};
obj[""] = name;
df.push({...obj, ...stats[name]});
});
return z.print(df);
} else {
return stats;
}
}
Insert cell
class plot4 {
constructor(data, selector) {
this.data = data;
this.elt = selector;
this.colors = {
r: "#D62728",
b: "#3366CC",
g: "#109618",
y: "#F2B601",
o: "#F58518",
pl: "#990099",
lb: "#0099C6",
pk: "#FF9DA6",
lv: "#636EFA",
lo: "#EF553B",
default: "#1F77B4"
};
}
type(kind) {
this.type = kind;
}
title(t) {
this.Title = t;
}
xlabel(xl) {
this.xLabel = xl;
}
ylabel(yl) {
this. yLabel = yl;
}
ticks(bool) {
this.showTicks = bool;
}
bins(Bins) {
this.nBins = Bins;
}
show() {
if (this.type == "hist") {
let hist;
if (!this.nBins) {
hist = d3.histogram();
} else {
let scale = d3.scaleLinear()
.domain(d3.extent(this.data));
hist = d3.histogram()
.thresholds(scale.ticks(this.nBins));
}
let bins = hist(this.data);
let x = Array.from(bins, elt => elt = `${elt["x0"]}-${elt["x1"]}`);
let y = Array.from(bins, elt => elt = elt.length);
let trace0 = {
type: "bar",
x: x,
y: y,
width: Array(bins.length).fill(1)
};
let data = [trace0];
let layout = {
title: {text: this.Title || ""},
xaxis: {title: {text: this.xLabel || ""},
showticklabels: this.showTicks && true},
yaxis: {title: {text: this.yLabel || ""}}
};
Plotly.newPlot(this.elt, data, layout);
}
else if (this.type == "hist-m") {
let hist;
if (!this.nBins) {
hist = d3.histogram();
} else {
let scale = d3.scaleLinear()
.domain(d3.extent(this.data));
hist = d3.histogram()
.thresholds(scale.ticks(this.nBins));
}
let data = [];
this.data.forEach(part => {
let bins = hist(part[0]);
let x = Array.from(bins, elt => elt = `${elt["x0"]}-${elt["x1"]}`);
let y = Array.from(bins, elt => elt = elt.length);
let trace = {
type: "bar",
x: x,
y: y,
width: Array(bins.length).fill(1)
};
if(part[1]) {
trace.name = part[1];
}
if(part[2]) {
let color = this.colors[part[2]];
trace.marker = {color};
}
data.push(trace);
});
let layout = {
title: {text: this.Title || ""},
xaxis: {title: {text: this.xLabel || ""},
showticklabels: this.showTicks && true},
yaxis: {title: {text: this.yLabel || ""}}
};
Plotly.newPlot(this.elt, data, layout);
}
else if (this.type == "scatter") {
let x = this.data[0];
let y = this.data[1];
let trace0 = {
type: "scatter",
x: x,
y: y,
mode: "markers"
};
let data = [trace0];
let layout = {
title: {text: this.Title || ""},
xaxis: {title: {text: this.xLabel || ""},
showticklabels: this.showTicks && true},
yaxis: {title: {text: this.yLabel || ""}}
};
Plotly.newPlot(this.elt, data, layout);
}
else if (this.type == "scatter-m") {
let data = [];
this.data.forEach(part => {
let x = part[0];
let y = part[1];
let trace = {
type: "scatter",
x: x,
y: y,
mode: "markers"
};
if(part[2]) {
trace.name = part[2];
}
if(part[3]) {
let color = this.colors[part[3]];
trace.marker = {color};
}
data.push(trace);
});
let layout = {
title: {text: this.Title || ""},
xaxis: {title: {text: this.xLabel || ""},
showticklabels: this.showTicks && true},
yaxis: {title: {text: this.yLabel || ""}}
};
Plotly.newPlot(this.elt, data, layout);
}
}
}
Insert cell
z.shape = (data) => {
let rows = data.length;
let columns = Object.keys(data[0]).length;
return [rows, columns];
}
Insert cell