Skip to content

Commit b9c0892

Browse files
TrucHLe and emar-kar
authored and committed
Enable users to pass in Pandas Dataframe when calling import_data() and batch_predict() from AutoML Tables client (googleapis#9116)
1 parent eb92473 commit b9c0892

9 files changed

Lines changed: 382 additions & 23 deletions

File tree

automl/google/cloud/automl_v1beta1/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.cloud.automl_v1beta1.gapic import auto_ml_client
2222
from google.cloud.automl_v1beta1.gapic import enums
2323
from google.cloud.automl_v1beta1.gapic import prediction_service_client
24+
from google.cloud.automl_v1beta1.tables import gcs_client
2425
from google.cloud.automl_v1beta1.tables import tables_client
2526

2627

@@ -38,4 +39,15 @@ class PredictionServiceClient(prediction_service_client.PredictionServiceClient)
3839
enums = enums
3940

4041

41-
__all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient", "TablesClient")
42+
class GcsClient(gcs_client.GcsClient):
    # Package-level re-export of the tables GCS helper. The subclass adds no
    # behaviour of its own; the docstring is copied explicitly so that the
    # re-exported class is documented exactly like the class it derives from.
    __doc__ = gcs_client.GcsClient.__doc__
44+
45+
46+
__all__ = (
47+
"enums",
48+
"types",
49+
"AutoMlClient",
50+
"PredictionServiceClient",
51+
"TablesClient",
52+
"GcsClient",
53+
)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright 2019 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Wraps the Google Cloud Storage client library for use in tables helper."""
18+
19+
import time
20+
21+
from google.api_core import exceptions
22+
23+
try:
24+
import pandas
25+
except ImportError: # pragma: NO COVER
26+
pandas = None
27+
28+
try:
29+
from google.cloud import storage
30+
except ImportError: # pragma: NO COVER
31+
storage = None
32+
33+
_PANDAS_REQUIRED = "pandas is required to verify type DataFrame."
34+
_STORAGE_REQUIRED = (
35+
"google-cloud-storage is required to create Google Cloud Storage client."
36+
)
37+
38+
39+
class GcsClient(object):
    """Stages Pandas DataFrames as CSV files in Google Cloud Storage.

    The wrapped ``storage.Client`` is exposed as ``self.client``.
    """

    def __init__(self, client=None, credentials=None):
        """Constructor.

        Args:
            client (Optional[storage.Client]): A Google Cloud Storage Client
                instance. When supplied it is used as-is and ``credentials``
                is ignored.
            credentials (Optional[google.auth.credentials.Credentials]): The
                authorization credentials to attach to requests. Only used
                when ``client`` is not supplied. If neither is given, a
                ``storage.Client`` is created with no arguments.

        Raises:
            ImportError: If ``google-cloud-storage`` is not installed.
        """
        if storage is None:
            raise ImportError(_STORAGE_REQUIRED)

        if client is not None:
            self.client = client
        elif credentials is not None:
            self.client = storage.Client(credentials=credentials)
        else:
            self.client = storage.Client()

    def ensure_bucket_exists(self, project, region):
        """Checks if a bucket named '{project}-automl-tables-staging' exists.

        Creates this bucket if it doesn't exist.

        Args:
            project (str): The project that stores the bucket.
            region (str): The region of the bucket.

        Returns:
            A string representing the staging bucket name (whether it already
            existed or was just created).
        """
        bucket_name = "{}-automl-tables-staging".format(project)

        try:
            self.client.get_bucket(bucket_name)
        except exceptions.NotFound:
            # Only a missing bucket triggers creation; any other API error
            # propagates to the caller unchanged.
            bucket = self.client.bucket(bucket_name)
            bucket.create(project=project, location=region)
        return bucket_name

    def upload_pandas_dataframe(self, bucket_name, dataframe, uploaded_csv_name=None):
        """Uploads a Pandas DataFrame as CSV to the bucket.

        Args:
            bucket_name (str): The bucket name to upload the CSV to.
            dataframe (pandas.DataFrame): The Pandas Dataframe to be uploaded.
            uploaded_csv_name (Optional[str]): The name for the uploaded CSV.
                Defaults to ``automl-tables-dataframe-{unix_seconds}.csv``.

        Returns:
            A string representing the GCS URI of the uploaded CSV.

        Raises:
            ImportError: If ``pandas`` is not installed.
            ValueError: If ``dataframe`` is not a ``pandas.DataFrame``.
        """
        if pandas is None:
            raise ImportError(_PANDAS_REQUIRED)

        if not isinstance(dataframe, pandas.DataFrame):
            raise ValueError("'dataframe' must be a pandas.DataFrame instance.")

        if uploaded_csv_name is None:
            uploaded_csv_name = "automl-tables-dataframe-{}.csv".format(
                int(time.time())
            )

        # Export without the DataFrame index. With the pandas default
        # (index=True) the CSV gains a leading column whose header is empty
        # and whose values are row labels, which is not part of the caller's
        # data and would be treated as an extra (unnamed) input column.
        csv_string = dataframe.to_csv(index=False)

        bucket = self.client.get_bucket(bucket_name)
        blob = bucket.blob(uploaded_csv_name)
        blob.upload_from_string(csv_string)

        return "gs://{}/{}".format(bucket_name, uploaded_csv_name)

automl/google/cloud/automl_v1beta1/tables/tables_client.py

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818

1919
import pkg_resources
2020
import logging
21+
import google.auth
2122

2223
from google.api_core.gapic_v1 import client_info
2324
from google.api_core import exceptions
2425
from google.cloud.automl_v1beta1 import gapic
2526
from google.cloud.automl_v1beta1.proto import data_types_pb2
27+
from google.cloud.automl_v1beta1.tables import gcs_client
2628

2729
_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version
2830
_LOGGER = logging.getLogger(__name__)
@@ -43,6 +45,7 @@ def __init__(
4345
region="us-central1",
4446
client=None,
4547
prediction_client=None,
48+
gcs_client=None,
4649
**kwargs
4750
):
4851
"""Constructor.
@@ -118,6 +121,7 @@ def __init__(
118121

119122
self.project = project
120123
self.region = region
124+
self.gcs_client = gcs_client
121125

122126
def __lookup_by_display_name(self, object_type, items, display_name):
123127
relevant_items = [i for i in items if i.display_name == display_name]
@@ -403,6 +407,21 @@ def __type_code_to_value_type(self, type_code, value):
403407
else:
404408
raise ValueError("Unknown type_code: {}".format(type_code))
405409

410+
def __ensure_gcs_client_is_initialized(self, credentials=None):
    """Lazily builds ``self.gcs_client`` the first time it is needed.

    Does nothing when ``self.gcs_client`` is already set.

    Args:
        credentials (google.auth.credentials.Credentials): The
            authorization credentials to attach to requests. These
            credentials identify this application to the service. If none
            are specified, they are obtained from ``google.auth.default()``.
    """
    if self.gcs_client is not None:
        # A client was injected at construction time or created by an
        # earlier call; keep using it.
        return

    if credentials is None:
        # Only the credentials are kept; the second element of the
        # returned pair is discarded.
        credentials, _ = google.auth.default()

    self.gcs_client = gcs_client.GcsClient(credentials=credentials)
424+
406425
def list_datasets(self, project=None, region=None, **kwargs):
407426
"""List all datasets in a particular project and region.
408427
@@ -642,10 +661,12 @@ def import_data(
642661
dataset=None,
643662
dataset_display_name=None,
644663
dataset_name=None,
664+
pandas_dataframe=None,
645665
gcs_input_uris=None,
646666
bigquery_input_uri=None,
647667
project=None,
648668
region=None,
669+
credentials=None,
649670
**kwargs
650671
):
651672
"""Imports data into a dataset.
@@ -679,6 +700,11 @@ def import_data(
679700
region (Optional[string]):
680701
If you have initialized the client with a value for `region` it
681702
will be used if this parameter is not supplied.
703+
credentials (Optional[google.auth.credentials.Credentials]): The
704+
authorization credentials to attach to requests. These
705+
credentials identify this application to the service. If none
706+
are specified, the client will attempt to ascertain the
707+
credentials from the environment.
682708
dataset_display_name (Optional[string]):
683709
The human-readable name given to the dataset you want to import
684710
data into. This must be supplied if `dataset` or `dataset_name`
@@ -691,13 +717,21 @@ def import_data(
691717
The `Dataset` instance you want to import data into. This must
692718
be supplied if `dataset_display_name` or `dataset_name` are not
693719
supplied.
720+
pandas_dataframe (Optional[pandas.DataFrame]):
721+
A Pandas Dataframe object containing the data to import. The data
722+
will be converted to CSV, and this CSV will be staged to GCS in
723+
`gs://{project}-automl-tables-staging/{uploaded_csv_name}`
724+
This parameter must be supplied if neither `gcs_input_uris` nor
725+
`bigquery_input_uri` is supplied.
694726
gcs_input_uris (Optional[Union[string, Sequence[string]]]):
695727
Either a single `gs://..` prefixed URI, or a list of URIs
696728
referring to GCS-hosted CSV files containing the data to
697-
import. This must be supplied if `bigquery_input_uri` is not.
729+
import. This must be supplied if neither `bigquery_input_uri`
730+
nor `pandas_dataframe` is supplied.
698731
bigquery_input_uri (Optional[string]):
699732
A URI pointing to the BigQuery table containing the data to
700-
import. This must be supplied if `gcs_input_uris` is not.
733+
import. This must be supplied if neither `gcs_input_uris` nor
734+
`pandas_dataframe` is supplied.
701735
702736
Returns:
703737
A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
@@ -720,15 +754,23 @@ def import_data(
720754
)
721755

722756
request = {}
723-
if gcs_input_uris is not None:
757+
758+
if pandas_dataframe is not None:
759+
self.__ensure_gcs_client_is_initialized(credentials)
760+
bucket_name = self.gcs_client.ensure_bucket_exists(project, region)
761+
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
762+
bucket_name, pandas_dataframe
763+
)
764+
request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
765+
elif gcs_input_uris is not None:
724766
if type(gcs_input_uris) != list:
725767
gcs_input_uris = [gcs_input_uris]
726768
request = {"gcs_source": {"input_uris": gcs_input_uris}}
727769
elif bigquery_input_uri is not None:
728770
request = {"bigquery_source": {"input_uri": bigquery_input_uri}}
729771
else:
730772
raise ValueError(
731-
"One of 'gcs_input_uris', or " "'bigquery_input_uri' must be set."
773+
"One of 'gcs_input_uris', or 'bigquery_input_uri', or 'pandas_dataframe' must be set."
732774
)
733775

734776
op = self.auto_ml_client.import_data(dataset_name, request, **kwargs)
@@ -2605,6 +2647,7 @@ def predict(
26052647

26062648
def batch_predict(
26072649
self,
2650+
pandas_dataframe=None,
26082651
bigquery_input_uri=None,
26092652
bigquery_output_uri=None,
26102653
gcs_input_uris=None,
@@ -2614,6 +2657,7 @@ def batch_predict(
26142657
model_display_name=None,
26152658
project=None,
26162659
region=None,
2660+
credentials=None,
26172661
inputs=None,
26182662
**kwargs
26192663
):
@@ -2645,15 +2689,30 @@ def batch_predict(
26452689
region (Optional[string]):
26462690
If you have initialized the client with a value for `region` it
26472691
will be used if this parameter is not supplied.
2692+
credentials (Optional[google.auth.credentials.Credentials]): The
2693+
authorization credentials to attach to requests. These
2694+
credentials identify this application to the service. If none
2695+
are specified, the client will attempt to ascertain the
2696+
credentials from the environment.
2697+
pandas_dataframe (Optional[pandas.DataFrame]):
2698+
A Pandas Dataframe object containing the data you want to predict
2699+
off of. The data will be converted to CSV, and this CSV will be
2700+
staged to GCS in `gs://{project}-automl-tables-staging/{uploaded_csv_name}`
2701+
This must be supplied if neither `gcs_input_uris` nor
2702+
`bigquery_input_uri` is supplied.
26482703
gcs_input_uris (Optional(Union[List[string], string]))
26492704
Either a list of or a single GCS URI containing the data you
2650-
want to predict off of.
2705+
want to predict off of. This must be supplied if neither
2706+
`pandas_dataframe` nor `bigquery_input_uri` is supplied.
26512707
gcs_output_uri_prefix (Optional[string])
2652-
The folder in GCS you want to write output to.
2708+
The folder in GCS you want to write output to. This must be
2709+
supplied if `bigquery_output_uri` is not.
26532710
bigquery_input_uri (Optional[string])
2654-
The BigQuery table to input data from.
2711+
The BigQuery table to input data from. This must be supplied if
2712+
neither `pandas_dataframe` nor `gcs_input_uris` is supplied.
26552713
bigquery_output_uri (Optional[string])
2656-
The BigQuery table to output data to.
2714+
The BigQuery table to output data to. This must be supplied if
2715+
`gcs_output_uri_prefix` is not.
26572716
model_display_name (Optional[string]):
26582717
The human-readable name given to the model you want to predict
26592718
with. This must be supplied if `model` or `model_name` are not
@@ -2688,7 +2747,15 @@ def batch_predict(
26882747
)
26892748

26902749
input_request = None
2691-
if gcs_input_uris is not None:
2750+
2751+
if pandas_dataframe is not None:
2752+
self.__ensure_gcs_client_is_initialized(credentials)
2753+
bucket_name = self.gcs_client.ensure_bucket_exists(project, region)
2754+
gcs_input_uri = self.gcs_client.upload_pandas_dataframe(
2755+
bucket_name, pandas_dataframe
2756+
)
2757+
input_request = {"gcs_source": {"input_uris": [gcs_input_uri]}}
2758+
elif gcs_input_uris is not None:
26922759
if type(gcs_input_uris) != list:
26932760
gcs_input_uris = [gcs_input_uris]
26942761
input_request = {"gcs_source": {"input_uris": gcs_input_uris}}

automl/noxfile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def default(session):
7070
for local_dep in LOCAL_DEPS:
7171
session.install("-e", local_dep)
7272
session.install("-e", ".")
73+
session.install("-e", ".[pandas,storage]")
7374

7475
# Run py.test against the unit tests.
7576
session.run(
@@ -117,6 +118,7 @@ def system(session):
117118
session.install("-e", local_dep)
118119
session.install("-e", "../test_utils/")
119120
session.install("-e", ".")
121+
session.install("-e", ".[pandas,storage]")
120122

121123
# Run py.test against the system tests.
122124
if system_test_exists:

automl/setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
"google-api-core[grpc] >= 1.14.0, < 2.0.0dev",
2626
'enum34; python_version < "3.4"',
2727
]
28+
extras = {
29+
"pandas": ["pandas>=0.24.0"],
30+
"storage": ["google-cloud-storage >= 1.18.0, < 2.0.0dev"],
31+
}
2832

2933
package_root = os.path.abspath(os.path.dirname(__file__))
3034

@@ -67,6 +71,7 @@
6771
packages=packages,
6872
namespace_packages=namespaces,
6973
install_requires=dependencies,
74+
extras_require=extras,
7075
python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*",
7176
include_package_data=True,
7277
zip_safe=False,

0 commit comments

Comments
 (0)