import os
import sys
import threading
+ import random
+ import uuid

from pyspark import SparkContext
- from pyspark.sql import Column, DataFrame, SparkSession, SQLContext, functions
- from pyspark.sql.functions import *
+ from pyspark.sql import Column, DataFrame, SparkSession, SQLContext, functions as F
from py4j.java_collections import MapConverter
from delta.tables import *
from multiprocessing.pool import ThreadPool
import time

"""
- create required dynamodb table with:
-
-   $ aws --region us-west-2 dynamodb create-table \
-       --table-name delta_log_test \
-       --attribute-definitions AttributeName=tablePath,AttributeType=S \
-                               AttributeName=fileName,AttributeType=S \
-       --key-schema AttributeName=tablePath,KeyType=HASH \
-                    AttributeName=fileName,KeyType=RANGE \
-       --provisioned-throughput ReadCapacityUnits=5,WriteCapacityUnits=5
-
run this script in root dir of repository:

export VERSION=$(cat version.sbt|cut -d '"' -f 2)
export DELTA_TABLE_PATH=s3a://test-bucket/delta-test/
export DELTA_DYNAMO_TABLE=delta_log_test
export DELTA_DYNAMO_REGION=us-west-2
- export DELTA_STORAGE=io.delta.storage.DynamoDBLogStoreScala  # TODO: remove `Scala` when Java version finished
+ export DELTA_STORAGE=io.delta.storage.DynamoDBLogStore
export DELTA_NUM_ROWS=16

./run-integration-tests.py --run-storage-dynamodb-integration-tests \
concurrent_readers = int(os.environ.get("DELTA_CONCURRENT_READERS", 2))
num_rows = int(os.environ.get("DELTA_NUM_ROWS", 16))

- # TODO change back to default io.delta.storage.DynamoDBLogStore
- delta_storage = os.environ.get("DELTA_STORAGE", "io.delta.storage.DynamoDBLogStoreScala")
+ delta_storage = os.environ.get("DELTA_STORAGE", "io.delta.storage.DynamoDBLogStore")
dynamo_table_name = os.environ.get("DELTA_DYNAMO_TABLE", "delta_log_test")
dynamo_region = os.environ.get("DELTA_DYNAMO_REGION", "us-west-2")
dynamo_error_rates = os.environ.get("DELTA_DYNAMO_ERROR_RATES", "")
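+ # when set to "false", the initial overwrite of the Delta table below is skipped so an existing table at DELTA_TABLE_PATH is reused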
+ table_overwrite = os.environ.get("DELTA_DYNAMO_TABLE_OVERWRITE", "true").lower() == "true"

if delta_table_path is None:
    print(f"\nSkipping Python test {os.path.basename(__file__)} due to the missing env variable "
    .master("local[*]") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.delta.logStore.class", delta_storage) \
-     .config("spark.delta.DynamoDBLogStoreScala.tableName", dynamo_table_name) \
-     .config("spark.delta.DynamoDBLogStoreScala.region", dynamo_region) \
-     .config("spark.delta.DynamoDBLogStoreScala.errorRates", dynamo_error_rates) \
+     .config("spark.delta.DynamoDBLogStore.tableName", dynamo_table_name) \
+     .config("spark.delta.DynamoDBLogStore.region", dynamo_region) \
+     .config("spark.delta.DynamoDBLogStore.errorRates", dynamo_error_rates) \
    .getOrCreate()

- data = spark.createDataFrame([], "id: int, a: int")
- data.write.format("delta").mode("overwrite").partitionBy("id").save(delta_table_path)
+ SCHEMA = "run_id: string, id: int, a: int"
+
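+ # unique tag for this run; rows are filtered on it so concurrent runs against the same table path do not affect each other's counts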
+ RUN_ID = str(uuid.uuid4())
+
+ data = spark.createDataFrame([], SCHEMA)
+
+ if table_overwrite:
+     data.write.format("delta").mode("overwrite").partitionBy("run_id", "id").save(delta_table_path)
+

def write_tx(n):
-     data = spark.createDataFrame([[n, n]], "id: int, a: int")
-     data.write.format("delta").mode("append").partitionBy("id").save(delta_table_path)
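+     # append one row tagged with this run's RUN_ID; the id column gets a random partition value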
+     data = spark.createDataFrame([[RUN_ID, random.randrange(2 ** 16), n]], SCHEMA)
+     data.write.format("delta").mode("append").partitionBy("run_id", "id").save(delta_table_path)
+
+
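+ # count only the rows written by this run (other runs may be appending to the same path)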
+ def count():
+     return (
+         spark.read.format("delta")
+         .load(delta_table_path)
+         .filter(F.col("run_id") == RUN_ID)
+         .count()
+     )


stop_reading = threading.Event()

def read_data():
    while not stop_reading.is_set():
-         print("Reading {:d} rows ...".format(spark.read.format("delta").load(delta_table_path).distinct().count()))
+         cnt = count()
+         print(f"Reading {cnt} rows ...")
        time.sleep(1)


@@ -127,7 +135,7 @@ def start_read_thread():
for thread in read_threads:
    thread.join()

- actual = spark.read.format("delta").load(delta_table_path).distinct().count()
+ actual = count()
print("Number of written rows:", actual)
assert actual == num_rows
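
Note: the docstring instructions for creating the DynamoDB table by hand were dropped in this change. If the table still has to be provisioned manually, a boto3 sketch equivalent to the removed `aws dynamodb create-table` command could look like the following (the helper name and waiter usage are illustrative, not part of this patch):

    import boto3

    def create_delta_log_table(table_name="delta_log_test", region="us-west-2"):
        # same key schema as the removed CLI command: tablePath (HASH) + fileName (RANGE)
        dynamodb = boto3.client("dynamodb", region_name=region)
        dynamodb.create_table(
            TableName=table_name,
            AttributeDefinitions=[
                {"AttributeName": "tablePath", "AttributeType": "S"},
                {"AttributeName": "fileName", "AttributeType": "S"},
            ],
            KeySchema=[
                {"AttributeName": "tablePath", "KeyType": "HASH"},
                {"AttributeName": "fileName", "KeyType": "RANGE"},
            ],
            ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
        )
        # block until the table is ACTIVE before running the test
        dynamodb.get_waiter("table_exists").wait(TableName=table_name)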