Skip to content

Commit 2c44fdd

Browse files
authored
Removed unnecessary aws_conn_id param from operators constructors (apache#51236)
* Removed unnecessary aws_conn_id param from operators constructors * Added regression tests to operators and renamed no_conn test to default_conn
1 parent aadcf45 commit 2c44fdd

File tree

9 files changed

+147
-14
lines changed

9 files changed

+147
-14
lines changed

providers/amazon/src/airflow/providers/amazon/aws/operators/cloud_formation.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,11 @@ def __init__(
9898
*,
9999
stack_name: str,
100100
cloudformation_parameters: dict | None = None,
101-
aws_conn_id: str | None = "aws_default",
102101
**kwargs,
103102
):
104103
super().__init__(**kwargs)
105104
self.cloudformation_parameters = cloudformation_parameters or {}
106105
self.stack_name = stack_name
107-
self.aws_conn_id = aws_conn_id
108106

109107
def execute(self, context: Context):
110108
self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters)

providers/amazon/src/airflow/providers/amazon/aws/operators/comprehend.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,6 @@ def __init__(
289289
waiter_delay: int = 60,
290290
waiter_max_attempts: int = 20,
291291
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
292-
aws_conn_id: str | None = "aws_default",
293292
**kwargs,
294293
):
295294
super().__init__(**kwargs)
@@ -305,7 +304,6 @@ def __init__(
305304
self.waiter_delay = waiter_delay
306305
self.waiter_max_attempts = waiter_max_attempts
307306
self.deferrable = deferrable
308-
self.aws_conn_id = aws_conn_id
309307

310308
def execute(self, context: Context) -> str:
311309
if self.output_data_config:

providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def __init__(
9191
table_mappings: dict,
9292
migration_type: str = "full-load",
9393
create_task_kwargs: dict | None = None,
94-
aws_conn_id: str | None = "aws_default",
9594
**kwargs,
9695
):
9796
super().__init__(**kwargs)
@@ -102,7 +101,6 @@ def __init__(
102101
self.migration_type = migration_type
103102
self.table_mappings = table_mappings
104103
self.create_task_kwargs = create_task_kwargs or {}
105-
self.aws_conn_id = aws_conn_id
106104

107105
def execute(self, context: Context):
108106
"""

providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,6 @@ def __init__(
313313
description: str = "AWS Glue Data Quality Rule Set With Airflow",
314314
update_rule_set: bool = False,
315315
data_quality_ruleset_kwargs: dict | None = None,
316-
aws_conn_id: str | None = "aws_default",
317316
**kwargs,
318317
):
319318
super().__init__(**kwargs)
@@ -322,7 +321,6 @@ def __init__(
322321
self.description = description
323322
self.update_rule_set = update_rule_set
324323
self.data_quality_ruleset_kwargs = data_quality_ruleset_kwargs or {}
325-
self.aws_conn_id = aws_conn_id
326324

327325
def validate_inputs(self) -> None:
328326
if not self.ruleset.startswith("Rules") or not self.ruleset.endswith("]"):
@@ -421,7 +419,6 @@ def __init__(
421419
waiter_delay: int = 60,
422420
waiter_max_attempts: int = 20,
423421
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
424-
aws_conn_id: str | None = "aws_default",
425422
**kwargs,
426423
):
427424
super().__init__(**kwargs)
@@ -437,7 +434,6 @@ def __init__(
437434
self.waiter_delay = waiter_delay
438435
self.waiter_max_attempts = waiter_max_attempts
439436
self.deferrable = deferrable
440-
self.aws_conn_id = aws_conn_id
441437

442438
def validate_inputs(self) -> None:
443439
glue_table = self.datasource.get("GlueTable", {})
@@ -584,7 +580,6 @@ def __init__(
584580
waiter_delay: int = 60,
585581
waiter_max_attempts: int = 20,
586582
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
587-
aws_conn_id: str | None = "aws_default",
588583
**kwargs,
589584
):
590585
super().__init__(**kwargs)
@@ -598,7 +593,6 @@ def __init__(
598593
self.waiter_delay = waiter_delay
599594
self.waiter_max_attempts = waiter_max_attempts
600595
self.deferrable = deferrable
601-
self.aws_conn_id = aws_conn_id
602596

603597
def execute(self, context: Context) -> str:
604598
glue_table = self.datasource.get("GlueTable", {})

providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,23 @@ def test_template_fields(self):
103103

104104
validate_template_fields(op)
105105

106+
def test_overwritten_conn_passed_to_hook(self):
107+
OVERWRITTEN_CONN = "new-conn-id"
108+
op = CloudFormationCreateStackOperator(
109+
task_id="cf_create_stack_pass_conn",
110+
stack_name="fake-stack",
111+
cloudformation_parameters={},
112+
aws_conn_id=OVERWRITTEN_CONN,
113+
)
114+
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
115+
116+
def test_default_conn_passed_to_hook(self):
117+
DEFAULT_CONN = "aws_default"
118+
op = CloudFormationCreateStackOperator(
119+
task_id="cf_create_stack_pass_default_conn", stack_name="fake-stack", cloudformation_parameters={}
120+
)
121+
assert op.hook.aws_conn_id == DEFAULT_CONN
122+
106123

107124
class TestCloudFormationDeleteStackOperator:
108125
def test_init(self):

providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,29 @@ def test_initialize_comprehend_base_operator_hook(self, comprehend_base_operator
8080
assert comprehend_base_op.client == mocked_client
8181
comprehend_base_operator_mock_hook.assert_called_once()
8282

83+
def test_overwritten_conn_passed_to_hook(self):
84+
OVERWRITTEN_CONN = "new-conn-id"
85+
op = ComprehendBaseOperator(
86+
task_id="comprehend_base_operator",
87+
input_data_config=INPUT_DATA_CONFIG,
88+
output_data_config=OUTPUT_DATA_CONFIG,
89+
language_code=LANGUAGE_CODE,
90+
data_access_role_arn=ROLE_ARN,
91+
aws_conn_id=OVERWRITTEN_CONN,
92+
)
93+
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
94+
95+
def test_default_conn_passed_to_hook(self):
96+
DEFAULT_CONN = "aws_default"
97+
op = ComprehendBaseOperator(
98+
task_id="comprehend_base_operator",
99+
input_data_config=INPUT_DATA_CONFIG,
100+
output_data_config=OUTPUT_DATA_CONFIG,
101+
language_code=LANGUAGE_CODE,
102+
data_access_role_arn=ROLE_ARN,
103+
)
104+
assert op.hook.aws_conn_id == DEFAULT_CONN
105+
83106

84107
class TestComprehendStartPiiEntitiesDetectionJobOperator:
85108
JOB_ID = "random-job-id-1234567"

providers/amazon/tests/unit/amazon/aws/operators/test_dms.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,27 @@ def test_template_fields(self):
149149

150150
validate_template_fields(op)
151151

152+
def test_overwritten_conn_passed_to_hook(self):
153+
OVERWRITTEN_CONN = "new-conn-id"
154+
op = DmsCreateTaskOperator(
155+
task_id="dms_create_task_operator",
156+
**self.TASK_DATA,
157+
aws_conn_id=OVERWRITTEN_CONN,
158+
verify=True,
159+
botocore_config={"read_timeout": 42},
160+
)
161+
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
162+
163+
def test_default_conn_passed_to_hook(self):
164+
DEFAULT_CONN = "aws_default"
165+
op = DmsCreateTaskOperator(
166+
task_id="dms_create_task_operator",
167+
**self.TASK_DATA,
168+
verify=True,
169+
botocore_config={"read_timeout": 42},
170+
)
171+
assert op.hook.aws_conn_id == DEFAULT_CONN
172+
152173

153174
class TestDmsDeleteTaskOperator:
154175
TASK_DATA = {

providers/amazon/tests/unit/amazon/aws/operators/test_glue.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,25 @@ def test_template_fields(self):
408408
)
409409
validate_template_fields(operator)
410410

411+
def test_overwritten_conn_passed_to_hook(self):
412+
OVERWRITTEN_CONN = "new-conn-id"
413+
op = GlueJobOperator(
414+
task_id=TASK_ID,
415+
aws_conn_id=OVERWRITTEN_CONN,
416+
iam_role_name="role_arn",
417+
replace_script_file=True,
418+
)
419+
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
420+
421+
def test_default_conn_passed_to_hook(self):
422+
DEFAULT_CONN = "aws_default"
423+
op = GlueJobOperator(
424+
task_id=TASK_ID,
425+
iam_role_name="role_arn",
426+
replace_script_file=True,
427+
)
428+
assert op.hook.aws_conn_id == DEFAULT_CONN
429+
411430

412431
class TestGlueDataQualityOperator:
413432
RULE_SET_NAME = "TestRuleSet"
@@ -542,6 +561,23 @@ def test_template_fields(self):
542561
)
543562
validate_template_fields(operator)
544563

564+
def test_overwritten_conn_passed_to_hook(self):
565+
OVERWRITTEN_CONN = "new-conn-id"
566+
op = GlueDataQualityOperator(
567+
task_id="test_overwritten_conn_passed_to_hook",
568+
name=self.RULE_SET_NAME,
569+
ruleset=self.RULE_SET,
570+
aws_conn_id=OVERWRITTEN_CONN,
571+
)
572+
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
573+
574+
def test_default_conn_passed_to_hook(self):
575+
DEFAULT_CONN = "aws_default"
576+
op = GlueDataQualityOperator(
577+
task_id="test_default_conn_passed_to_hook", name=self.RULE_SET_NAME, ruleset=self.RULE_SET
578+
)
579+
assert op.hook.aws_conn_id == DEFAULT_CONN
580+
545581

546582
class TestGlueDataQualityRuleSetEvaluationRunOperator:
547583
RUN_ID = "1234567890"
@@ -648,6 +684,29 @@ def test_start_data_quality_ruleset_evaluation_run_wait_combinations(
648684
def test_template_fields(self):
649685
validate_template_fields(self.operator)
650686

687+
def test_overwritten_conn_passed_to_hook(self):
688+
OVERWRITTEN_CONN = "new-conn-id"
689+
op = GlueDataQualityRuleSetEvaluationRunOperator(
690+
task_id="test_overwritten_conn_passed_to_hook",
691+
datasource=self.DATA_SOURCE,
692+
role=self.ROLE,
693+
rule_set_names=self.RULE_SET_NAMES,
694+
show_results=False,
695+
aws_conn_id=OVERWRITTEN_CONN,
696+
)
697+
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
698+
699+
def test_default_conn_passed_to_hook(self):
700+
DEFAULT_CONN = "aws_default"
701+
op = GlueDataQualityRuleSetEvaluationRunOperator(
702+
task_id="test_default_conn_passed_to_hook",
703+
datasource=self.DATA_SOURCE,
704+
role=self.ROLE,
705+
rule_set_names=self.RULE_SET_NAMES,
706+
show_results=False,
707+
)
708+
assert op.hook.aws_conn_id == DEFAULT_CONN
709+
651710

652711
class TestGlueDataQualityRuleRecommendationRunOperator:
653712
RUN_ID = "1234567890"
@@ -756,3 +815,28 @@ def test_start_data_quality_rule_recommendation_run_wait_combinations(
756815

757816
def test_template_fields(self):
758817
validate_template_fields(self.operator)
818+
819+
def test_overwritten_conn_passed_to_hook(self):
820+
OVERWRITTEN_CONN = "new-conn-id"
821+
op = GlueDataQualityRuleRecommendationRunOperator(
822+
task_id="test_overwritten_conn_passed_to_hook",
823+
datasource=self.DATA_SOURCE,
824+
role=self.ROLE,
825+
number_of_workers=10,
826+
timeout=1000,
827+
recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"},
828+
aws_conn_id=OVERWRITTEN_CONN,
829+
)
830+
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
831+
832+
def test_default_conn_passed_to_hook(self):
833+
DEFAULT_CONN = "aws_default"
834+
op = GlueDataQualityRuleRecommendationRunOperator(
835+
task_id="test_default_conn_passed_to_hook",
836+
datasource=self.DATA_SOURCE,
837+
role=self.ROLE,
838+
number_of_workers=10,
839+
timeout=1000,
840+
recommendation_run_kwargs={"CreatedRulesetName": "test-ruleset"},
841+
)
842+
assert op.hook.aws_conn_id == DEFAULT_CONN

providers/amazon/tests/unit/amazon/aws/operators/test_rds.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ def test_overwritten_conn_passed_to_hook(self):
181181
)
182182
assert op.hook.aws_conn_id == OVERWRITTEN_CONN
183183

184-
def test_no_conn_passed_to_hook(self):
184+
def test_default_conn_passed_to_hook(self):
185185
DEFAULT_CONN = "aws_default"
186-
op = RdsBaseOperator(task_id="test_no_conn_passed_to_hook_task", dag=self.dag)
186+
op = RdsBaseOperator(task_id="test_default_conn_passed_to_hook_task", dag=self.dag)
187187
assert op.hook.aws_conn_id == DEFAULT_CONN
188188

189189

0 commit comments

Comments
 (0)