
Commit f973095: Setting up Glue 4.0
neilagupta committed Dec 9, 2022 (1 parent: c513156)
Showing 9 changed files with 186 additions and 37 deletions.
2 changes: 1 addition & 1 deletion NOTICE.txt
@@ -1,3 +1,3 @@
aws-glue-libs
Copyright 2016-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2016-2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.

15 changes: 9 additions & 6 deletions README.md
@@ -11,12 +11,13 @@ This repository contains:

Different Glue versions support different Python versions. The following table is for your reference; it also lists the associated aws-glue-libs branch for each Glue version.

| Glue Version | Python 2 Version | Python 3 Version | aws-glue-libs branch|
|---|---|---| --- |
| 0.9 | 2.7 | Not supported | glue-0.9 |
| 1.0 | 2.7 | 3.6 | glue-1.0 |
| 2.0 | Not supported | 3.7 | glue-2.0 |
| 3.0 | Not supported | 3.7 | master |
| Glue Version | Python 2 Version | Python 3 Version | aws-glue-libs branch |
|---|---|---|----------------------|
| 0.9 | 2.7 | Not supported | glue-0.9 |
| 1.0 | 2.7 | 3.6 | glue-1.0 |
| 2.0 | Not supported | 3.7 | glue-2.0 |
| 3.0 | Not supported | 3.7 | glue-3.0 |
| 4.0 | Not supported | 3.10 | master |

You may refer to AWS Glue's official [release notes](https://docs.aws.amazon.com/glue/latest/dg/release-notes.html) for more information.

@@ -33,12 +34,14 @@ The `awsglue` library provides only the Python interface to the Glue Spark runti
* Glue version 1.0: `https://aws-glue-etl-artifacts.s3.amazonaws.com/glue-1.0/spark-2.4.3-bin-hadoop2.8.tgz`
* Glue version 2.0: `https://aws-glue-etl-artifacts.s3.amazonaws.com/glue-2.0/spark-2.4.3-bin-hadoop2.8.tgz`
* Glue version 3.0: `https://aws-glue-etl-artifacts.s3.amazonaws.com/glue-3.0/spark-3.1.1-amzn-0-bin-3.2.1-amzn-3.tgz`
* Glue version 4.0: `https://aws-glue-etl-artifacts.s3.amazonaws.com/glue-4.0/spark-3.3.0-amzn-1-bin-3.3.3-amzn-0.tgz`
1. export the `SPARK_HOME` environment variable to the location of the extracted Spark distribution above. For example:
```
Glue version 0.9: export SPARK_HOME=/home/$USER/spark-2.2.1-bin-hadoop2.7
Glue version 1.0: export SPARK_HOME=/home/$USER/spark-2.4.3-bin-hadoop2.8
Glue version 2.0: export SPARK_HOME=/home/$USER/spark-2.4.3-bin-hadoop2.8
Glue version 3.0: export SPARK_HOME=/home/$USER/spark-3.1.1-amzn-0-bin-3.2.1-amzn-3
Glue version 4.0: export SPARK_HOME=/home/$USER/spark-3.3.0-amzn-1-bin-3.3.3-amzn-0
```
1. now you can run the executables in the `bin` directory to start a Glue Shell or submit a Glue Spark application.
```
# start a Glue PySpark shell
./bin/gluepyspark
# submit a Glue Spark application
./bin/gluesparksubmit <your-script>.py
```
163 changes: 138 additions & 25 deletions awsglue/context.py
@@ -19,6 +19,7 @@
from awsglue.streaming_data_source import StreamingDataSource
from awsglue.data_sink import DataSink
from awsglue.dataframereader import DataFrameReader
from awsglue.dataframewriter import DataFrameWriter
from awsglue.dynamicframe import DynamicFrame, DynamicFrameReader, DynamicFrameWriter, DynamicFrameCollection
from awsglue.gluetypes import DataType
from awsglue.utils import makeOptions, callsite
@@ -41,13 +42,16 @@ def register(sc):
java_import(sc._jvm, "com.amazonaws.services.glue.util.AWSConnectionUtils")
java_import(sc._jvm, "com.amazonaws.services.glue.util.GluePythonUtils")
java_import(sc._jvm, "com.amazonaws.services.glue.errors.CallSite")
java_import(sc._jvm, "com.amazonaws.services.glue.ml.EntityDetector")
java_import(sc._jvm, "com.amazonaws.services.glue.dq.EvaluateDataQuality")
# java_import(sc._jvm, "com.amazonaws.services.glue.ml.FindMatches")
# java_import(sc._jvm, "com.amazonaws.services.glue.ml.FindIncrementalMatches")
# java_import(sc._jvm, "com.amazonaws.services.glue.ml.FillMissingValues")


class GlueContext(SQLContext):
Spark_SQL_Formats = {"parquet", "orc"}
Unsupported_Compression_Types = {"lzo"}

def __init__(self, sparkContext, **options):
super(GlueContext, self).__init__(sparkContext)
@@ -56,6 +60,7 @@ def __init__(self, sparkContext, **options):
self.create_dynamic_frame = DynamicFrameReader(self)
self.create_data_frame = DataFrameReader(self)
self.write_dynamic_frame = DynamicFrameWriter(self)
self.write_data_frame = DataFrameWriter(self)
self.spark_session = SparkSession(sparkContext, self._glue_scala_context.getSparkSession())
self._glue_logger = sparkContext._jvm.GlueLogger()

@@ -89,7 +94,11 @@ def getSource(self, connection_type, format = None, transformation_ctx = "", pus
>>> myFrame = data_source.getFrame()
"""
options["callSite"] = callsite()
if(format and format.lower() in self.Spark_SQL_Formats):
compressionType = options.get("compressionType", "")
if compressionType in self.Unsupported_Compression_Types and format is None:
raise Exception("When using compressionType {}, the format parameter must be specified.".format(compressionType))
# If the compression type is unsupported, fall back to the Spark SQL datasource.
if((format and format.lower() in self.Spark_SQL_Formats) or (compressionType in self.Unsupported_Compression_Types)):
connection_type = format

j_source = self._ssql_ctx.getSource(connection_type,
@@ -222,11 +231,56 @@ def create_dynamic_frame_from_options(self, connection_type, connection_options=
"""
source = self.getSource(connection_type, format, transformation_ctx, push_down_predicate, **connection_options)

if (format and format not in self.Spark_SQL_Formats):
if (format and format not in self.Spark_SQL_Formats and connection_options.get("compressionType", "") not in self.Unsupported_Compression_Types):
source.setFormat(format, **format_options)

return source.getFrame(**kwargs)
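As a usage sketch of the new compression handling (the bucket path is hypothetical, and it assumes a Glue 4.0 job or Glue PySpark shell where the runtime jars are available): an `lzo` source must declare a `format`, and such sources are routed through the Spark SQL datasource path instead of `setFormat`.

```
from pyspark.context import SparkContext
from awsglue.context import GlueContext

glue_context = GlueContext(SparkContext.getOrCreate())

# Without a format this raises: "When using compressionType lzo, the format parameter must be specified."
# glue_context.create_dynamic_frame_from_options(
#     connection_type="s3",
#     connection_options={"paths": ["s3://my-bucket/raw/"], "compressionType": "lzo"})

# With a format, getSource() falls back to the Spark SQL datasource and setFormat() is skipped.
frame = glue_context.create_dynamic_frame_from_options(
    connection_type="s3",
    connection_options={"paths": ["s3://my-bucket/raw/"], "compressionType": "lzo"},
    format="parquet")
frame.printSchema()
```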

def create_sample_dynamic_frame_from_catalog(self, database = None, table_name = None, num = None, sample_options = {}, redshift_tmp_dir = "",
transformation_ctx = "", push_down_predicate="", additional_options = {},
catalog_id = None, erieTxId = "", asOfTime = "", **kwargs):
"""
Returns a sample of records, as a DynamicFrame, from the given catalog database and table name with an optional catalog id
:param database: database in catalog
:param table_name: table name
:param num: number of sample records
:param sample_options: options for sampling behavior
:param transformation_ctx: transformation context
:param push_down_predicate: predicate applied to filter partitions of the source table before sampling
:param additional_options: a collection of optional name-value pairs
:param catalog_id: catalog id of the DataCatalog being accessed (account id of the data catalog).
Set to None by default (None defaults to the catalog id of the calling account in the service)
:return: a sample DynamicFrame with potential errors
"""
if database is not None and "name_space" in kwargs:
raise Exception("Parameter name_space and database are both specified, choose one.")
elif database is None and "name_space" not in kwargs:
raise Exception("Parameter name_space or database is missing.")
elif "name_space" in kwargs:
db = kwargs.pop("name_space")
else:
db = database

if table_name is None:
raise Exception("Parameter table_name is missing.")
source = DataSource(self._ssql_ctx.getCatalogSource(db, table_name, redshift_tmp_dir, transformation_ctx,
push_down_predicate,
makeOptions(self._sc, additional_options), catalog_id),
self, table_name)
return source.getSampleFrame(num, **sample_options)

def create_sample_dynamic_frame_from_options(self, connection_type, connection_options={}, num = None, sample_options = {},
format=None, format_options={}, transformation_ctx = "", push_down_predicate= "", **kwargs):
"""Creates a list of sample dynamic records with the specified connection and format.
"""
source = self.getSource(connection_type, format, transformation_ctx, push_down_predicate, **connection_options)

if (format and format not in self.Spark_SQL_Formats):
source.setFormat(format, **format_options)

return source.getSampleFrame(num, **sample_options)
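A short usage sketch of the new sampling entry points (database, table, and S3 path are hypothetical; assumes a Glue job or Glue PySpark shell). Both calls return a small DynamicFrame built from at most `num` records.

```
from pyspark.context import SparkContext
from awsglue.context import GlueContext

glue_context = GlueContext(SparkContext.getOrCreate())

# Sample up to 20 records from a Data Catalog table.
catalog_sample = glue_context.create_sample_dynamic_frame_from_catalog(
    database="sales_db", table_name="orders", num=20)
catalog_sample.printSchema()

# Sample directly from an S3 location instead of the catalog.
s3_sample = glue_context.create_sample_dynamic_frame_from_options(
    connection_type="s3",
    connection_options={"paths": ["s3://my-bucket/orders/"]},
    num=20,
    format="json")
s3_sample.show()
```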


def create_data_frame_from_options(self, connection_type, connection_options={},
format=None, format_options={}, transformation_ctx = "", push_down_predicate= "", **kwargs):
"""Creates a DataFrame with the specified connection and format. Used for streaming data sources
@@ -332,6 +386,24 @@ def write_dynamic_frame_from_catalog(self, frame, database = None, table_name =
makeOptions(self._sc, additional_options), catalog_id)
return DataSink(j_sink, self).write(frame)

def write_data_frame_from_catalog(self, frame, database = None, table_name = None, redshift_tmp_dir = "",
transformation_ctx = "", additional_options = {}, catalog_id = None, **kwargs):
if database is not None and "name_space" in kwargs:
raise Exception("Parameter name_space and database are both specified, choose one.")
elif database is None and "name_space" not in kwargs:
raise Exception("Parameter name_space or database is missing.")
elif "name_space" in kwargs:
db = kwargs.pop("name_space")
else:
db = database

if table_name is None:
raise Exception("Parameter table_name is missing.")

j_sink = self._ssql_ctx.getCatalogSink(db, table_name, redshift_tmp_dir, transformation_ctx,
makeOptions(self._sc, additional_options), catalog_id)
return DataSink(j_sink, self).writeDataFrame(frame, self)
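A hedged sketch of the new DataFrame write path (database and table names are hypothetical, and it assumes the target catalog table already exists and its sink supports direct DataFrame writes). It mirrors `write_dynamic_frame_from_catalog` but hands a Spark DataFrame straight to the sink.

```
from pyspark.context import SparkContext
from awsglue.context import GlueContext

glue_context = GlueContext(SparkContext.getOrCreate())
spark = glue_context.spark_session

df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

glue_context.write_data_frame_from_catalog(
    frame=df,
    database="analytics_db",   # hypothetical
    table_name="events")       # hypothetical
```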

def write_dynamic_frame_from_jdbc_conf(self, frame, catalog_connection, connection_options={},
redshift_tmp_dir = "", transformation_ctx = "", catalog_id = None):
"""
@@ -478,6 +550,55 @@ def get_logger(self):
def currentTimeMillis(self):
return int(round(time.time() * 1000))

def getSampleStreamingDynamicFrame(self, frame, options={}, batch_function=None):
if "windowSize" not in options:
raise ValueError("Missing windowSize argument")

windowSize = options["windowSize"]
pollingTimeInMs = int(options.get("pollingTimeInMs", 10000))
recordPollingLimit = int(options.get("recordPollingLimit", 100))

# Parse the window size locally in Python; a Py4J limitation prevents reusing the JVM-side implementation here
def convert_window_size_to_milis(window_size):
if type(window_size) != str or " " not in window_size.strip():
raise ValueError("Received invalid window size")
chunks = window_size.strip().split(" ")
if len(chunks) != 2:
raise ValueError("Received invalid window size")
unit = chunks[1].lower()
if "second" in unit:
multiplier = 1000
elif "minute" in unit:
multiplier = 1000 * 60
elif "hour" in unit:
multiplier = 1000 * 60 * 60
else:
raise ValueError("Received invalid window size")
try:
quantity = int(chunks[0])
except:
raise ValueError("Received invalid window size")
return quantity * multiplier

windowSizeInMilis = convert_window_size_to_milis(windowSize)
if windowSizeInMilis >= pollingTimeInMs:
raise ValueError("Polling time needs to be larger than window size")

tableId = str(uuid.uuid4()).replace("-", "")
writer = frame.writeStream\
.trigger(processingTime=windowSize)\
.queryName(tableId)\
.format("memory")
if batch_function is not None:
writer = writer.foreachBatch(batch_function)

query = writer.start()
resultDF = self.spark_session.sql("select * from " + tableId + " limit " + str(recordPollingLimit))
time.sleep(pollingTimeInMs / 1000)
query.stop()
return DynamicFrame.fromDF(resultDF, self, tableId)
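A sketch of sampling a streaming source with the new helper (the Kinesis ARN and connection options are illustrative): the frame is written to an in-memory table for the polling period and the collected rows come back as a DynamicFrame. Note that `windowSize` must be smaller than `pollingTimeInMs`.

```
from pyspark.context import SparkContext
from awsglue.context import GlueContext

glue_context = GlueContext(SparkContext.getOrCreate())

# Streaming DataFrame from Kinesis (stream ARN and options are illustrative).
stream_df = glue_context.create_data_frame_from_options(
    connection_type="kinesis",
    connection_options={
        "streamARN": "arn:aws:kinesis:us-east-1:123456789012:stream/my-stream",
        "startingPosition": "TRIM_HORIZON",
        "classification": "json"},
    transformation_ctx="stream_df")

# Collect up to 50 records over 5-second windows, polling for 10 seconds in total.
sample = glue_context.getSampleStreamingDynamicFrame(
    stream_df,
    options={"windowSize": "5 seconds",
             "pollingTimeInMs": "10000",
             "recordPollingLimit": "50"})
sample.show()
```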


def forEachBatch(self, frame, batch_function, options = {}):
if "windowSize" not in options:
raise Exception("Missing windowSize argument")
@@ -487,21 +608,8 @@ def forEachBatch(self, frame, batch_function, options = {}):
windowSize = options["windowSize"]
checkpointLocation = options["checkpointLocation"]

# Check the Glue version
glue_ver = self.getConf('spark.glue.GLUE_VERSION', '')
java_import(self._jvm, "org.apache.spark.metrics.source.StreamingSource")

# Converting the S3 scheme to S3a for the Glue Streaming checkpoint location in connector jars.
# S3 scheme on checkpointLocation currently doesn't work on Glue 2.0 (non-EMR).
# Will remove this once the connector package is imported as brazil package.
if (glue_ver == '2.0' or glue_ver == '2' or glue_ver == '3.0' or glue_ver == '3'):
if (checkpointLocation.startswith( 's3://' )):
java_import(self._jvm, "com.amazonaws.regions.RegionUtils")
java_import(self._jvm, "com.amazonaws.services.s3.AmazonS3")
self._jsc.hadoopConfiguration().set("fs.s3a.endpoint", self._jvm.RegionUtils.getRegion(
self._jvm.AWSConnectionUtils.getRegion()).getServiceEndpoint(self._jvm.AmazonS3.ENDPOINT_PREFIX))
checkpointLocation = checkpointLocation.replace( 's3://', 's3a://', 1)

run = {'value': 0}
retry_attempt = {'value': 0}

@@ -512,7 +620,7 @@ def batch_function_with_persist(data_frame, batchId):
run['value'] = 0
if retry_attempt['value'] > 0:
retry_attempt['value'] = 0
logging.warning("The batch is now succeeded. Resetting retry attempt counter to zero.")
logging.info("The previous batch was succeeded. Reset the retry attempt counter to 0.")
run['value'] += 1

# process the batch
@@ -538,15 +646,20 @@

while (True):
try:
if retry_attempt['value'] > 0:
logging.warning("Retrying micro batch processing, attempt {} out of {}. ".format(retry_attempt['value'], batch_max_retries))
query.start().awaitTermination()
except Exception as e:

if str(e).startswith("CheckpointMetadataNotFound"):
raise e

retry_attempt['value'] += 1
logging.warning("StreamingQueryException caught. Retry number " + str(retry_attempt['value']))

if retry_attempt['value'] > batch_max_retries:
logging.error("Exceeded maximuim number of retries in streaming interval, exception thrown")
self._glue_logger.error("Exceeded the maximum number of batch retries. Throwing the exception. ")
raise e
# lastFailedAttempt = failedTime

backOffTime = retry_attempt['value'] if (retry_attempt['value'] < 3) else 5
time.sleep(backOffTime)

Expand All @@ -560,11 +673,11 @@ def batch_function_with_persist(data_frame, batchId):
def add_ingestion_time_columns(self, frame, time_granularity):
return DataFrame(self._ssql_ctx.addIngestionTimeColumns(frame._jdf, time_granularity), frame.sql_ctx)

def begin_transaction(self, read_only):
return self._ssql_ctx.beginTransaction(read_only)
def start_transaction(self, read_only):
return self._ssql_ctx.startTransaction(read_only)

def commit_transaction(self, transaction_id):
return self._ssql_ctx.commitTransaction(transaction_id)
def commit_transaction(self, transaction_id, wait_for_commit=True):
return self._ssql_ctx.commitTransaction(transaction_id, wait_for_commit)

def abort_transaction(self, transaction_id):
return self._ssql_ctx.abortTransaction(transaction_id)
def cancel_transaction(self, transaction_id):
return self._ssql_ctx.cancelTransaction(transaction_id)
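A hedged sketch of the renamed transaction helpers: `start_transaction` replaces `begin_transaction`, `commit_transaction` gains a `wait_for_commit` flag, and `cancel_transaction` replaces `abort_transaction`. The database, table, and the `transactionId` write option below are assumptions based on how governed-table writes are typically wired up, not part of this diff.

```
from pyspark.context import SparkContext
from awsglue.context import GlueContext
from awsglue.dynamicframe import DynamicFrame

glue_context = GlueContext(SparkContext.getOrCreate())
spark = glue_context.spark_session
dyf = DynamicFrame.fromDF(spark.createDataFrame([(1, "a")], ["id", "value"]), glue_context, "dyf")

tx_id = glue_context.start_transaction(read_only=False)    # was begin_transaction
try:
    glue_context.write_dynamic_frame.from_catalog(
        frame=dyf,
        database="governed_db",                             # hypothetical
        table_name="governed_table",                        # hypothetical
        additional_options={"transactionId": tx_id})        # option name assumed
    glue_context.commit_transaction(tx_id, wait_for_commit=True)
except Exception:
    glue_context.cancel_transaction(tx_id)                  # was abort_transaction
    raise
```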
4 changes: 4 additions & 0 deletions awsglue/data_sink.py
@@ -12,6 +12,7 @@

from awsglue.dynamicframe import DynamicFrame, DynamicFrameCollection
from awsglue.utils import makeOptions, callsite
from pyspark.sql import DataFrame

class DataSink(object):
def __init__(self, j_sink, sql_ctx):
@@ -30,6 +31,9 @@ def setCatalogInfo(self, catalogDatabase, catalogTableName, catalogId = ""):
def writeFrame(self, dynamic_frame, info = ""):
return DynamicFrame(self._jsink.pyWriteDynamicFrame(dynamic_frame._jdf, callsite(), info), dynamic_frame.glue_ctx, dynamic_frame.name + "_errors")

def writeDataFrame(self, data_frame, glue_context, info = ""):
return DataFrame(self._jsink.pyWriteDataFrame(data_frame._jdf, glue_context._glue_scala_context, callsite(), info), self._sql_ctx)

def write(self, dynamic_frame_or_dfc, info = ""):
if isinstance(dynamic_frame_or_dfc, DynamicFrame):
return self.writeFrame(dynamic_frame_or_dfc, info)
4 changes: 4 additions & 0 deletions awsglue/data_source.py
@@ -38,3 +38,7 @@ def getFrame(self, **options):
jframe = self._jsource.getDynamicFrame(minPartitions, targetPartitions)

return DynamicFrame(jframe, self._sql_ctx, self.name)

def getSampleFrame(self, num, **options):
jframe = self._jsource.getSampleDynamicFrame(num, makeOptions(self._sql_ctx._sc, options))
return DynamicFrame(jframe, self._sql_ctx, self.name)
21 changes: 21 additions & 0 deletions awsglue/dataframewriter.py
@@ -0,0 +1,21 @@
class DataFrameWriter(object):
def __init__(self, glue_context):
self._glue_context = glue_context
def from_catalog(self, frame, database=None, table_name=None, redshift_tmp_dir="", transformation_ctx="",
additional_options={}, catalog_id=None, **kwargs):
"""Writes a DataFrame with the specified catalog name space and table name.
"""
if database is not None and "name_space" in kwargs:
raise Exception("Parameter name_space and database are both specified, choose one.")
elif database is None and "name_space" not in kwargs:
raise Exception("Parameter name_space or database is missing.")
elif "name_space" in kwargs:
db = kwargs.pop("name_space")
else:
db = database

if table_name is None:
raise Exception("Parameter table_name is missing.")

return self._glue_context.write_data_frame_from_catalog(frame, db, table_name, redshift_tmp_dir,
transformation_ctx, additional_options, catalog_id)
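Since `GlueContext` now exposes this writer as `write_data_frame` (see the context.py change above), the same write can also be phrased through the reader/writer-style API. Names below are hypothetical; the call simply delegates to `write_data_frame_from_catalog`.

```
from pyspark.context import SparkContext
from awsglue.context import GlueContext

glue_context = GlueContext(SparkContext.getOrCreate())
df = glue_context.spark_session.createDataFrame([(1, "a")], ["id", "value"])

# Delegates to glue_context.write_data_frame_from_catalog(...).
glue_context.write_data_frame.from_catalog(
    frame=df,
    database="analytics_db",   # hypothetical
    table_name="events")       # hypothetical
```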
10 changes: 7 additions & 3 deletions awsglue/dynamicframe.py
@@ -74,10 +74,10 @@ def schema(self):
raise Exception("Unable to parse datatype from schema. %s" % e)
return self._schema

def show(self, num_rows = 20):
self._jdf.show(num_rows)
def show(self, num_rows=20):
print(self._jdf.showString(num_rows))

def filter(self, f, transformation_ctx = "", info="", stageThreshold=0, totalThreshold=0):
def filter(self, f, transformation_ctx="", info="", stageThreshold=0, totalThreshold=0):
def wrap_dict_with_dynamic_records(x):
rec = _create_dynamic_record(x["record"])
try:
@@ -387,6 +387,10 @@ def _to_java_mapping(mapping_tup):

return DynamicFrame(new_jdf, self.glue_ctx, self.name)

def unnest_ddb_json(self, transformation_ctx="", info="", stageThreshold=0, totalThreshold=0):
new_jdf = self._jdf.unnestDDBJson(transformation_ctx, _call_site(self._sc, callsite(), info), long(stageThreshold), long(totalThreshold))
return DynamicFrame(new_jdf, self.glue_ctx, self.name)
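A sketch of the new `unnest_ddb_json` helper (the S3 export path is hypothetical). It is intended to flatten DynamoDB-typed JSON such as `{"id": {"S": "1"}}` into plain columns; the exact output depends on the source data.

```
from pyspark.context import SparkContext
from awsglue.context import GlueContext

glue_context = GlueContext(SparkContext.getOrCreate())

# DynamicFrame read from a DynamoDB export written to S3 (path is hypothetical).
raw = glue_context.create_dynamic_frame_from_options(
    connection_type="s3",
    connection_options={"paths": ["s3://my-export-bucket/AWSDynamoDB/data/"]},
    format="json")

flat = raw.unnest_ddb_json(transformation_ctx="flat")
flat.printSchema()
```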

def resolveChoice(self, specs=None, choice="", database=None, table_name=None,
transformation_ctx="", info="", stageThreshold=0, totalThreshold=0, catalog_id=None):
"""
