1818
1919import pkg_resources
2020import logging
21+ import google .auth
2122
2223from google .api_core .gapic_v1 import client_info
2324from google .api_core import exceptions
2425from google .cloud .automl_v1beta1 import gapic
2526from google .cloud .automl_v1beta1 .proto import data_types_pb2
27+ from google .cloud .automl_v1beta1 .tables import gcs_client
2628
2729_GAPIC_LIBRARY_VERSION = pkg_resources .get_distribution ("google-cloud-automl" ).version
2830_LOGGER = logging .getLogger (__name__ )
@@ -43,6 +45,7 @@ def __init__(
4345 region = "us-central1" ,
4446 client = None ,
4547 prediction_client = None ,
48+ gcs_client = None ,
4649 ** kwargs
4750 ):
4851 """Constructor.
@@ -118,6 +121,7 @@ def __init__(
118121
119122 self .project = project
120123 self .region = region
124+ self .gcs_client = gcs_client
121125
122126 def __lookup_by_display_name (self , object_type , items , display_name ):
123127 relevant_items = [i for i in items if i .display_name == display_name ]
@@ -403,6 +407,21 @@ def __type_code_to_value_type(self, type_code, value):
403407 else :
404408 raise ValueError ("Unknown type_code: {}" .format (type_code ))
405409
def __ensure_gcs_client_is_initialized(self, credentials=None):
    """Lazily create the GCS client stored on ``self.gcs_client``.

    Does nothing when ``self.gcs_client`` is already set; otherwise a
    new ``gcs_client.GcsClient`` is built and stored on the instance.

    Args:
        credentials (google.auth.credentials.Credentials): The
            authorization credentials handed to the new GCS client. If
            none are specified, they are obtained from
            ``google.auth.default()``.
    """
    # A client supplied at construction time, or created by an earlier
    # call, is reused as-is; ``credentials`` is ignored in that case.
    if self.gcs_client is not None:
        return

    if credentials is None:
        # google.auth.default() returns a pair; only the first element
        # (the credentials) is needed here.
        credentials, _ = google.auth.default()

    self.gcs_client = gcs_client.GcsClient(credentials=credentials)
424+
406425 def list_datasets (self , project = None , region = None , ** kwargs ):
407426 """List all datasets in a particular project and region.
408427
@@ -642,10 +661,12 @@ def import_data(
642661 dataset = None ,
643662 dataset_display_name = None ,
644663 dataset_name = None ,
664+ pandas_dataframe = None ,
645665 gcs_input_uris = None ,
646666 bigquery_input_uri = None ,
647667 project = None ,
648668 region = None ,
669+ credentials = None ,
649670 ** kwargs
650671 ):
651672 """Imports data into a dataset.
@@ -679,6 +700,11 @@ def import_data(
679700 region (Optional[string]):
680701 If you have initialized the client with a value for `region` it
681702 will be used if this parameter is not supplied.
703+ credentials (Optional[google.auth.credentials.Credentials]): The
704+ authorization credentials to attach to requests. These
705+ credentials identify this application to the service. If none
706+ are specified, the client will attempt to ascertain the
707+ credentials from the environment.
682708 dataset_display_name (Optional[string]):
683709 The human-readable name given to the dataset you want to import
684710 data into. This must be supplied if `dataset` or `dataset_name`
@@ -691,13 +717,21 @@ def import_data(
691717 The `Dataset` instance you want to import data into. This must
692718 be supplied if `dataset_display_name` or `dataset_name` are not
693719 supplied.
720+ pandas_dataframe (Optional[pandas.DataFrame]):
721+ A Pandas Dataframe object containing the data to import. The data
722+ will be converted to CSV, and this CSV will be staged to GCS in
723+ `gs://{project}-automl-tables-staging/{uploaded_csv_name}`.
724+ This parameter must be supplied if neither `gcs_input_uris` nor
725+ `bigquery_input_uri` is supplied.
694726 gcs_input_uris (Optional[Union[string, Sequence[string]]]):
695727 Either a single `gs://..` prefixed URI, or a list of URIs
696728 referring to GCS-hosted CSV files containing the data to
697- import. This must be supplied if `bigquery_input_uri` is not.
729+ import. This must be supplied if neither `bigquery_input_uri`
730+ nor `pandas_dataframe` is supplied.
698731 bigquery_input_uri (Optional[string]):
699732 A URI pointing to the BigQuery table containing the data to
700- import. This must be supplied if `gcs_input_uris` is not.
733+ import. This must be supplied if neither `gcs_input_uris` nor
734+ `pandas_dataframe` is supplied.
701735
702736 Returns:
703737 A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
@@ -720,15 +754,23 @@ def import_data(
720754 )
721755
722756 request = {}
723- if gcs_input_uris is not None :
757+
758+ if pandas_dataframe is not None :
759+ self .__ensure_gcs_client_is_initialized (credentials )
760+ bucket_name = self .gcs_client .ensure_bucket_exists (project , region )
761+ gcs_input_uri = self .gcs_client .upload_pandas_dataframe (
762+ bucket_name , pandas_dataframe
763+ )
764+ request = {"gcs_source" : {"input_uris" : [gcs_input_uri ]}}
765+ elif gcs_input_uris is not None :
724766 if type (gcs_input_uris ) != list :
725767 gcs_input_uris = [gcs_input_uris ]
726768 request = {"gcs_source" : {"input_uris" : gcs_input_uris }}
727769 elif bigquery_input_uri is not None :
728770 request = {"bigquery_source" : {"input_uri" : bigquery_input_uri }}
729771 else :
730772 raise ValueError (
731- "One of 'gcs_input_uris', or " " 'bigquery_input_uri' must be set."
773+ "One of 'gcs_input_uris', or 'bigquery_input_uri', or 'pandas_dataframe ' must be set."
732774 )
733775
734776 op = self .auto_ml_client .import_data (dataset_name , request , ** kwargs )
@@ -2605,6 +2647,7 @@ def predict(
26052647
26062648 def batch_predict (
26072649 self ,
2650+ pandas_dataframe = None ,
26082651 bigquery_input_uri = None ,
26092652 bigquery_output_uri = None ,
26102653 gcs_input_uris = None ,
@@ -2614,6 +2657,7 @@ def batch_predict(
26142657 model_display_name = None ,
26152658 project = None ,
26162659 region = None ,
2660+ credentials = None ,
26172661 inputs = None ,
26182662 ** kwargs
26192663 ):
@@ -2645,15 +2689,30 @@ def batch_predict(
26452689 region (Optional[string]):
26462690 If you have initialized the client with a value for `region` it
26472691 will be used if this parameter is not supplied.
2692+ credentials (Optional[google.auth.credentials.Credentials]): The
2693+ authorization credentials to attach to requests. These
2694+ credentials identify this application to the service. If none
2695+ are specified, the client will attempt to ascertain the
2696+ credentials from the environment.
2697+ pandas_dataframe (Optional[pandas.DataFrame]):
2698+ A Pandas Dataframe object containing the data you want to predict
2699+ off of. The data will be converted to CSV, and this CSV will be
2700+ staged to GCS in `gs://{project}-automl-tables-staging/{uploaded_csv_name}`.
2701+ This must be supplied if neither `gcs_input_uris` nor
2702+ `bigquery_input_uri` is supplied.
26482703 gcs_input_uris (Optional[Union[List[string], string]]):
26492704 Either a list of or a single GCS URI containing the data you
2650- want to predict off of.
2705+ want to predict off of. This must be supplied if neither
2706+ `pandas_dataframe` nor `bigquery_input_uri` is supplied.
26512707 gcs_output_uri_prefix (Optional[string]):
2652- The folder in GCS you want to write output to.
2708+ The folder in GCS you want to write output to. This must be
2709+ supplied if `bigquery_output_uri` is not.
26532710 bigquery_input_uri (Optional[string]):
2654- The BigQuery table to input data from.
2711+ The BigQuery table to input data from. This must be supplied if
2712+ neither `pandas_dataframe` nor `gcs_input_uris` is supplied.
26552713 bigquery_output_uri (Optional[string]):
2656- The BigQuery table to output data to.
2714+ The BigQuery table to output data to. This must be supplied if
2715+ `gcs_output_uri_prefix` is not.
26572716 model_display_name (Optional[string]):
26582717 The human-readable name given to the model you want to predict
26592718 with. This must be supplied if `model` or `model_name` are not
@@ -2688,7 +2747,15 @@ def batch_predict(
26882747 )
26892748
26902749 input_request = None
2691- if gcs_input_uris is not None :
2750+
2751+ if pandas_dataframe is not None :
2752+ self .__ensure_gcs_client_is_initialized (credentials )
2753+ bucket_name = self .gcs_client .ensure_bucket_exists (project , region )
2754+ gcs_input_uri = self .gcs_client .upload_pandas_dataframe (
2755+ bucket_name , pandas_dataframe
2756+ )
2757+ input_request = {"gcs_source" : {"input_uris" : [gcs_input_uri ]}}
2758+ elif gcs_input_uris is not None :
26922759 if type (gcs_input_uris ) != list :
26932760 gcs_input_uris = [gcs_input_uris ]
26942761 input_request = {"gcs_source" : {"input_uris" : gcs_input_uris }}
0 commit comments