import json
import time

import boto3

# Account ID that shares its snapshots with this account.
SOURCE_ACCOUNT = "987698769876"

# Account ID this function runs in; snapshot copies are created here.
TARGET_ACCOUNT = "123412341234"

# KMS key translation table used when copying. Each key is a KMS key ARN from
# the source account (as reported in the source snapshot's KmsKeyId); the
# value is the KMS key ARN in this account used to encrypt the copy.
KMS_KEYS = {
    "arn:aws:kms:ca-central-1:987698769876:key/abcd-abcd-abcd": (
        "arn:aws:kms:ca-central-1:123412341234:key/qrst-qrst-qrst"
    ),
}

# Grace period, in seconds, before a copy whose source snapshot has vanished
# is deleted. Each key is a substring looked for in the copy's description; on
# a hit, its value is added to the time at which the copy was marked for
# removal. Keys are tried in the order written here and the first match wins.
# A copy matching no key (or a key with value 0) is deleted on the first run
# after it is marked.
RETENTION = {
    "hour_snapshot": 0,
    "day_snapshot": 24 * 60 * 60,
    "week_snapshot": 3 * 24 * 60 * 60,
    "month_snapshot": 3 * 24 * 60 * 60,
    "year_snapshot": 3 * 24 * 60 * 60,
}

# Tag written on each copy, holding the ID of the source snapshot it came from.
SOURCE_TAG = "SourceSnapshotId"

# Tag written on an orphaned copy. Despite its name, its value is an absolute
# Unix timestamp (str(float)) after which the copy may be deleted.
RETENTION_TAG = "RetentionPeriod"

# Region the source snapshots are in.
REGION = "ca-central-1"

print("START")

ec2 = boto3.client('ec2')


def lambda_handler(event, context):
    """Mirror snapshots shared by SOURCE_ACCOUNT into this account.

    Source snapshots with no copy here are copied (re-encrypted with the
    mapped KMS key). Copies whose source snapshot has disappeared are first
    tagged with a deletion deadline and deleted on a later run once that
    deadline has passed. Copies without SOURCE_TAG were not created by this
    function and are never touched.

    Returns a Lambda proxy-style dict whose body is a JSON list of
    [source_id, target_id, message] entries, one per action taken.
    """
    print("Fetching snapshots for source account " + SOURCE_ACCOUNT)
    source_snapshots = fetch_snapshots(SOURCE_ACCOUNT)
    print("Fetching snapshots for target account: " + TARGET_ACCOUNT)
    target_snapshots = fetch_snapshots(TARGET_ACCOUNT)

    source_snapshot_index = {s['SnapshotId']: s for s in source_snapshots}
    target_snapshots_mapped = {}    # source snapshot id -> its copy here
    target_snapshots_orphaned = {}  # copy id -> copy whose source is gone

    for snapshot in target_snapshots:
        source_snapshot_id = _tag_value(snapshot, SOURCE_TAG)
        if source_snapshot_id is None:
            # Not a copy made by this function; leave it alone.
            continue
        if source_snapshot_id in source_snapshot_index:
            target_snapshots_mapped[source_snapshot_id] = snapshot
        else:
            target_snapshots_orphaned[snapshot['SnapshotId']] = snapshot

    results = [
        handle_snapshot(source_snapshot,
                        target_snapshots_mapped.get(source_snapshot_id))
        for source_snapshot_id, source_snapshot in source_snapshot_index.items()
    ]
    results.extend(
        handle_snapshot(None, orphan)
        for orphan in target_snapshots_orphaned.values()
    )
    # handle_snapshot returns None when nothing was done.
    results = [result for result in results if result]

    return {
        'statusCode': 200,
        'body': json.dumps(results)
    }


def fetch_snapshots(account_id):
    """Return every snapshot visible to this account owned by *account_id*.

    Uses the describe_snapshots paginator so that all result pages are read,
    not only the first one.
    """
    paginator = ec2.get_paginator('describe_snapshots')
    pages = paginator.paginate(
        Filters=[
            {
                'Name': 'owner-id',
                'Values': [account_id]
            }
        ]
    )
    snapshots = []
    for page in pages:
        snapshots.extend(page['Snapshots'])
    return snapshots


def _tag_value(snapshot, key):
    """Return the value of tag *key* on *snapshot*, or None if it is absent.

    The 'Tags' field may be missing entirely on untagged snapshots, so it is
    read with a default instead of being indexed directly.
    """
    for tag in snapshot.get('Tags', []):
        if tag['Key'] == key:
            return tag['Value']
    return None


def handle_snapshot(source_snapshot, target_snapshot):
    """Reconcile one source snapshot with its copy in this account.

    :param source_snapshot: describe_snapshots entry from SOURCE_ACCOUNT, or
        None if the source no longer exists.
    :param target_snapshot: the copy in this account, or None if none exists.
    :returns: a [source_id, target_id, message] list describing the action
        taken, or None when nothing was done.
    """
    if source_snapshot is not None and target_snapshot is not None:
        # Source and target exist, nothing to do!
        return None
    if source_snapshot is not None:
        return _copy_snapshot(source_snapshot)
    if target_snapshot is not None:
        return _expire_snapshot(target_snapshot)
    return None


def _copy_snapshot(source_snapshot):
    """Copy a completed source snapshot into this account with the mapped KMS key."""
    source_id = source_snapshot['SnapshotId']
    print(source_id + " copying")

    # Quick guard -- make sure the snapshot is complete first.
    if source_snapshot['Progress'] != '100%':
        return [source_id, None, "Source snapshot incomplete. Doing nothing."]

    # An unencrypted source has no KmsKeyId; an encrypted one may use a key
    # missing from KMS_KEYS. Either way, report it and keep processing the
    # other snapshots instead of aborting the whole run with a KeyError.
    kms_key_id = KMS_KEYS.get(source_snapshot.get('KmsKeyId'))
    if kms_key_id is None:
        print(source_id + ": no KMS key mapping, skipping")
        return [source_id, None,
                "No KMS key mapping for source snapshot. Doing nothing."]

    copy_response = ec2.copy_snapshot(
        # The description keeps the source's description, so the RETENTION
        # substrings still match on the copy later.
        Description='/'.join([
            source_snapshot['OwnerId'],
            source_snapshot['VolumeId'],
            str(source_snapshot['StartTime']),
            source_snapshot.get('Description', '')
        ]),
        Encrypted=True,
        KmsKeyId=kms_key_id,
        SourceRegion=REGION,
        SourceSnapshotId=source_id,
        TagSpecifications=[
            {
                'ResourceType': 'snapshot',
                'Tags': [
                    {
                        "Key": SOURCE_TAG,
                        "Value": source_id
                    }
                ]
            }
        ]
    )
    return [source_id, copy_response['SnapshotId'], "Copied snapshot"]


def _expire_snapshot(target_snapshot):
    """Mark an orphaned copy for deletion, or delete it once its deadline passed."""
    target_id = target_snapshot['SnapshotId']
    deadline_tag = _tag_value(target_snapshot, RETENTION_TAG)

    if deadline_tag is None:
        # First time seen orphaned: store now + grace period as the deadline.
        print(target_id + ": marked for removal")
        description = target_snapshot.get('Description', '')
        # First RETENTION key found in the description wins; default 0.
        grace = next(
            (seconds for search_str, seconds in RETENTION.items()
             if search_str in description),
            0,
        )
        deadline = time.time() + grace
        ec2.create_tags(
            Resources=[target_id],
            Tags=[
                {
                    'Key': RETENTION_TAG,
                    'Value': str(deadline)
                }
            ]
        )
        return [None, target_id, "Source removed. Marking for deletion."]

    if float(deadline_tag) < time.time():
        print(target_id + ": retention time passed, removing")
        ec2.delete_snapshot(SnapshotId=target_id)
        return [None, target_id, "Retention period passed. Deleted."]

    # Still within its retention window; nothing to do yet.
    return None