import copy
import logging
import os
import tempfile
from typing import Union
import pandas as pd
from requests_toolbelt import MultipartEncoder
from modep_client.client import Client
logger = logging.getLogger(__name__)
[docs]class Datasets:
def __init__(self, client: Client):
"""
Initialize the Datasets class
:param client: A :class:`modep_client.client.Client` object
"""
self.client = client
[docs] def upload(
self,
dset: Union[str, pd.DataFrame],
name: str,
target: str = None,
categorical_target: bool = True,
):
"""
Upload a tabular dataset.
:param dset: either a path to a CSV file or DataFrame containing the data
:type dset: str or :class:`pandas.DataFrame`
:param str name: A name to give the dataset (ie. `titanic-train` or `titanic-test`)
:param target: Optionally specify a target column for the dataset
:type target: str or None
:param bool categorical_target: `True` if the specified `target` column is categorical
(for classification), otherwise set this to `False` for regression.
"""
if isinstance(dset, str):
path = dset
if not os.path.exists(path):
raise Exception(f"Path does not exist: '{path}'")
elif isinstance(dset, pd.DataFrame):
path = tempfile.NamedTemporaryFile(suffix="-df-upload").name + ".csv"
logger.info("Writing DataFrame to %s", path)
dset.to_csv(path, index=False)
else:
raise ValueError(
"Unknown type for dataset, "
"must be either string path or pd.DataFrame"
)
logger.info("Uploading from %s", path)
url = self.client.url + "datasets/tabular"
# deepcopy since we update headers below
headers = copy.deepcopy(self.client.auth_header())
with open(path, "rb") as f:
data = MultipartEncoder(
{
"path": os.path.abspath(path),
"name": name,
"file": (path, f, "text/csv/h5"),
"target": target,
"categorical_target": str(categorical_target),
}
)
headers.update(
{"Prefer": "respond-async", "Content-Type": data.content_type}
)
resp = self.client.sess.post(url, data=data, headers=headers)
if resp.ok:
return resp.json()
else:
self.client.response_exception(resp)
[docs] def get(self, id: str):
"""
Get a dataset by id.
:param str id: The id of the dataset
:return: A dictionary for the dataset
"""
url = self.client.url + "datasets/tabular/" + str(id)
resp = self.client.sess.get(url, headers=self.client.auth_header())
if resp.ok:
return resp.json()
else:
self.client.response_exception(resp)
[docs] def list(self):
"""
List all datasets.
:return: A list of dictionaries for each uploaded or public dataset
"""
url = self.client.url + "datasets/tabular"
resp = self.client.sess.get(url, headers=self.client.auth_header())
if resp.ok:
js = resp.json()
df = pd.DataFrame(js)
if len(df) > 0:
# keep column order same as json
df = df[list(js[0].keys())].set_index("id")
df = df.sort_values(by="created", ascending=False)
return df
else:
self.client.response_exception(resp)
[docs] def delete(self, dataset_id):
"""
Delete a dataset by id.
:param str id: The id of the dataset
:return: A dictionary containing information on the deletion
"""
url = self.client.url + "datasets/tabular/" + str(dataset_id)
resp = self.client.sess.delete(url, headers=self.client.auth_header())
if resp.ok:
return resp.json()
else:
self.client.response_exception(resp)